Merge branch 'pull_request_4'
authorChristopher Zorn <christopher@mochimedia.com>
Mon, 2 May 2011 18:46:14 +0000 (11:46 -0700)
committerChristopher Zorn <christopher@mochimedia.com>
Mon, 2 May 2011 18:46:14 +0000 (11:46 -0700)
14 files changed:
README.txt
punjab.tac
punjab/__init__.py
punjab/error.py
punjab/httpb.py
punjab/httpb_client.py
punjab/jabber.py
punjab/patches.py [new file with mode: 0644]
punjab/session.py
punjab/xmpp/server.py
tests/test_basic.py
tests/testparser.py
tests/xep124.py
tests/xep206.py

index b6f583c..7be4795 100644 (file)
@@ -17,7 +17,7 @@ Copyright (C) 2001-2010 Christopher Zorn , tofu@thetofu.com
 GENERAL INFORMATION
 
 PunJab is a HTTP jabber client interface. It is a BOSH connection manager that
-allows persistent client connections to a XMPP server. 
+allows persistent client connections to a XMPP server.
 
 For more information about punjab see the following URL :
 
@@ -26,5 +26,6 @@ http://code.stanziq.com/punjab
 
 CONTRIBUTORS
 
-Jack Moffitt xmpp:jackm@jabber.org - Improved HTTP Binding and Polling
-Garret Heaton <powdahound@gmail.com> - Whitelist improvements and bugfixes.
+Jack Moffitt xmpp:jackm@jabber.org
+Garret Heaton <powdahound@gmail.com>
+Zewt (https://github.com/zewt)
index a313bf5..1197730 100644 (file)
@@ -11,7 +11,7 @@ root = static.File("./html")
 #bosh = HttpbService(1, use_raw=True)
 bosh = HttpbService(1)
 
-# You can limit servers with a whitelist. 
+# You can limit servers with a whitelist.
 # The whitelist is a list of strings to match domain names.
 # bosh.white_list = ['jabber.org', 'thetofu.com']
 # or a black list
index 70bf013..a562f49 100644 (file)
@@ -4,6 +4,7 @@ Punjab - multiple http interfaces to jabber.
 """
 from twisted.python import log
 from twisted.application import service
+import patches
 
 
 def uriCheck(elem, uri):
@@ -43,9 +44,8 @@ class Service(service.Service):
             log.msg('Punjab Error: ')
             log.msg(failure.printBriefTraceback())
             log.msg(body)
-        failure.raiseException()                
-        
-            
+        failure.raiseException()
+
     def success(self, result, body = None):
         """
         If success we log it and return result
@@ -60,11 +60,10 @@ def makeService(config):
     Create a punjab service to run
     """
     from twisted.web import  server, resource, static
-    from twisted.application import service, internet
+    from twisted.application import internet
 
     import httpb
 
-
     serviceCollection = PunjabService()
 
     if config['html_dir']:
@@ -102,7 +101,7 @@ def makeService(config):
         sm.setServiceParent(serviceCollection)
     else:
         sm = internet.TCPServer(int(config['port']), site)
-        
+
         sm.setServiceParent(serviceCollection)
 
     serviceCollection.httpb = b
index 8e36e12..a671c42 100644 (file)
@@ -8,9 +8,10 @@ class Error(Exception):
     children     = []
     def __init__(self,msg = None):
         Exception.__init__(self)
-        self.stanza_error = msg
-        self.punjab_error = msg
-        self.msg          = msg
+        if msg:
+            self.stanza_error = msg
+            self.punjab_error = msg
+            self.msg          = msg
         
     def __str__(self):
         return self.stanza_error
@@ -60,6 +61,7 @@ conditions = {
     'see-other-uri':   {'code': '200', 'type': 'terminate'},
     'system-shutdown': {'code': '200', 'type': 'terminate'},
     'undefined-condition':     {'code': '200', 'type': 'terminate'},
+    'item-not-found':          {'code': '200', 'type': 'terminate'},
     
 }
 
index d76b829..837f24b 100644 (file)
@@ -13,7 +13,7 @@ try:
 except ImportError:
     from twisted.xish import domish
 
-import sha, time
+import hashlib, time
 import error
 from session import make_session
 import punjab
@@ -45,7 +45,7 @@ class HttpbElementStream(domish.ExpatElementStream):
         if prefixes:
             self.prefixes.update(prefixes)
         self.prefixes.update(domish.G_PREFIXES)
-        self.prefixStack = [domish.G_PREFIXES.values()] 
+        self.prefixStack = [domish.G_PREFIXES.values()]
         self.prefixCounter = 0
 
 
@@ -65,7 +65,6 @@ class HttpbElementStream(domish.ExpatElementStream):
     def _onStartElement(self, name, attrs):
         # Generate a qname tuple from the provided name
         attr_str   = ''
-        prefix_str = ''
         defaultUri = None
         uri        = None
         qname = name.split(" ")
@@ -77,45 +76,45 @@ class HttpbElementStream(domish.ExpatElementStream):
         if self.currElem:
             defaultUri = self.currElem.defaultUri
             uri = self.currElem.uri
-            
+
         if not defaultUri and currentUri in self.defaultNsStack:
             defaultUri = self.defaultNsStack[1]
-        
+
         if defaultUri and currentUri != defaultUri:
 
             raw_xml = u"""<%s xmlns='%s'%s""" % (qname[1], qname[0], '%s')
-        
+
         else:
             raw_xml = u"""<%s%s""" % (qname[1], '%s')
 
 
         # Process attributes
-        
+
         for k, v in attrs.items():
             if k.find(" ") != -1:
-                aqname = k.split(" ")                
+                aqname = k.split(" ")
                 attrs[(aqname[0], aqname[1])] = v
-                
+
                 attr_prefix = self.getPrefix(aqname[0])
                 if not self.prefixInScope(attr_prefix):
-                    attr_str = attr_str + " xmlns:%s='%s'" % (attr_prefix, 
+                    attr_str = attr_str + " xmlns:%s='%s'" % (attr_prefix,
                                                               aqname[0])
                     self.prefixStack[-1].append(attr_prefix)
-                attr_str = attr_str + " %s:%s='%s'" % (attr_prefix, 
+                attr_str = attr_str + " %s:%s='%s'" % (attr_prefix,
                                                        aqname[1],
-                                                       domish.escapeToXml(v, 
+                                                       domish.escapeToXml(v,
                                                                           True))
                 del attrs[k]
             else:
                 v = domish.escapeToXml(v, True)
-                attr_str = attr_str + " " + k + "='" + v + "'" 
+                attr_str = attr_str + " " + k + "='" + v + "'"
 
         raw_xml = raw_xml % (attr_str,)
-        
+
         # Construct the new element
         e = domish.Element(qname, self.defaultNsStack[-1], attrs, self.localPrefixes)
         self.localPrefixes = {}
-        
+
         # Document already started
         if self.documentStarted == 1:
             if self.currElem != None:
@@ -125,7 +124,7 @@ class HttpbElementStream(domish.ExpatElementStream):
 
                 self.currElem.children.append(e)
                 e.parent = self.currElem
-            
+
             self.currRawElem = self.currRawElem + raw_xml
             self.currElem = e
         # New document
@@ -138,7 +137,7 @@ class HttpbElementStream(domish.ExpatElementStream):
         # Check for null current elem; end of doc
         if self.currElem is None:
             self.DocumentEndEvent()
-            
+
         # Check for parent that is None; that's
         # the top of the stack
         elif self.currElem.parent is None:
@@ -166,7 +165,7 @@ class HttpbElementStream(domish.ExpatElementStream):
             else:
                 self.currRawElem = self.currRawElem  + domish.escapeToXml(data)
                 #self.currRawElem = self.currRawElem  + data
-            
+
             self.currElem.addContent(data)
 
     def _onStartNamespace(self, prefix, uri):
@@ -197,7 +196,7 @@ def elementStream():
         return es
 
 # make httpb body class, similar to xmlrpclib
-# 
+#
 class HttpbParse:
     """
     An xml parser for parsing the body elements.
@@ -215,10 +214,10 @@ class HttpbParse:
         Parse incoming xml and return the body and its children in a list
         """
         self.stream.parse(buf)
-        
+
         # return the doc element and its children in a list
-        return self.body, self.xmpp_elements 
-    
+        return self.body, self.xmpp_elements
+
     def serialize(self, obj):
         """
         Turn object into a string type
@@ -230,7 +229,7 @@ class HttpbParse:
     def onDocumentStart(self, rootelem):
         """
         The body document has started.
-        
+
         This should be a body.
         """
         if rootelem.name == 'body':
@@ -262,7 +261,7 @@ class HttpbParse:
         self.stream.DocumentEndEvent = self.onDocumentEnd
         self.body = ""
         self.xmpp_elements = []
-        
+
 
     def onDocumentEnd(self):
         """
@@ -299,9 +298,9 @@ class IHttpbService(Interface):
 
     def getXmppElements(self, body, session):
         """ """
-        
 
-        
+
+
 class IHttpbFactory(Interface):
     """
     Factory class for generating binding sessions.
@@ -318,8 +317,8 @@ class IHttpbFactory(Interface):
     def buildProtocol(self, addr):
         """Return a protocol """
 
-        
-    
+
+
 class Httpb(resource.Resource):
     """
     Http resource to handle BOSH requests.
@@ -343,7 +342,7 @@ class Httpb(resource.Resource):
         request.setHeader('Access-Control-Allow-Headers', 'Content-Type')
         request.setHeader('Access-Control-Max-Age', '86400')
         return ""
-                
+
     def render_GET(self, request):
         """
         GET is not used, print docs.
@@ -368,11 +367,11 @@ class Httpb(resource.Resource):
             log.msg(request.received_headers)
             log.msg("HTTPB POST : ")
             log.msg(str(request.content.read()))
-            request.content.seek(0, 0)       
+            request.content.seek(0, 0)
 
         self.hp       = HttpbParse()
         try:
-            body_tag, xmpp_elements = self.hp.parse(request.content.read()) 
+            body_tag, xmpp_elements = self.hp.parse(request.content.read())
             self.hp._reset()
 
             if getattr(body_tag, 'name', '') != "body":
@@ -384,13 +383,13 @@ class Httpb(resource.Resource):
             log.msg('ERROR: Xml Parse Error')
             log.err()
             self.hp._reset()
-            self.send_http_error(400, request) 
+            self.send_http_error(400, request)
             return server.NOT_DONE_YET
         except:
             log.err()
             # reset parser, just in case
             self.hp._reset()
-            self.send_http_error(400, request) 
+            self.send_http_error(400, request)
             return server.NOT_DONE_YET
         else:
             if self.service.inSession(body_tag):
@@ -399,7 +398,7 @@ class Httpb(resource.Resource):
                     request.rid = body_tag['rid']
                     if self.service.v:
                         log.msg(request.rid)
-                
+
                 s, d = self.service.parseBody(body_tag, xmpp_elements)
                 d.addCallback(self.return_httpb, s, request)
             elif body_tag.hasAttribute('sid'):
@@ -412,11 +411,11 @@ class Httpb(resource.Resource):
                 # start session
                 s, d = self.service.startSession(body_tag, xmpp_elements)
                 d.addCallback(self.return_session, s, request)
-                
+
             # Add an error back for returned errors
             d.addErrback(self.return_error, request)
         return server.NOT_DONE_YET
-        
+
 
     def return_session(self, data, session, request):
         # create body
@@ -424,8 +423,8 @@ class Httpb(resource.Resource):
             self.send_http_error(200, request, 'remote-connection-failed',
                                  'terminate')
             return server.NOT_DONE_YET
-        
-        b = domish.Element((NS_BIND, "body"))       
+
+        b = domish.Element((NS_BIND, "body"))
         # if we don't have an authid, we have to fail
         if session.authid != 0:
             b['authid'] = session.authid
@@ -433,20 +432,20 @@ class Httpb(resource.Resource):
             self.send_http_error(500, request, 'internal-server-error',
                                  'terminate')
             return server.NOT_DONE_YET
-        
+
         b['sid']  = session.sid
         b['wait'] = str(session.wait)
         if session.secure == 0:
             b['secure'] = 'false'
         else:
             b['secure'] = 'true'
-            
+
         b['inactivity'] = str(session.inactivity)
         ##b['polling'] = '15' # TODO: make this configurable
-        b['polling'] = str(self.polling) 
+        b['polling'] = str(self.polling)
         b['requests'] = str(session.hold + 1)
         b['window'] = str(session.window)
-        
+
         punjab.uriCheck(b, NS_BIND)
         if session.attrs.has_key('content'):
             b['content'] = session.attrs['content']
@@ -462,7 +461,7 @@ class Httpb(resource.Resource):
         self.return_body(request, b)
 
     def return_httpb(self, data, session, request):
-        # create body                
+        # create body
         b = domish.Element((NS_BIND, "body"))
         punjab.uriCheck(b, NS_BIND)
         session.touch()
@@ -470,26 +469,26 @@ class Httpb(resource.Resource):
             b['type']      = 'terminate'
         if data:
             b.children += data
-        
-        self.return_body(request, b, session.charset)        
 
-    
+        self.return_body(request, b, session.charset)
+
+
     def return_error(self, e, request):
         echildren = []
-        
+
         try:
             # TODO - clean this up and make errors better
             if getattr(e.value,'stanza_error',None):
                 ec = getattr(e.value, 'children', None)
                 if ec:
                     echildren = ec
-                    
+
                 self.send_http_error(error.conditions[str(e.value.stanza_error)]['code'],
                                      request,
                                      condition = str(e.value.stanza_error),
                                      typ = error.conditions[str(e.value.stanza_error)]['type'],
                                      children=echildren)
-                
+
                 return  server.NOT_DONE_YET
             elif e.value:
                 self.send_http_error(error.conditions[str(e.value)]['code'],
@@ -503,11 +502,11 @@ class Httpb(resource.Resource):
             log.err()
             pass
 
-    
+
     def return_body(self, request, b, charset="utf-8"):
         request.setResponseCode(200)
         bxml = b.toXml(prefixes=ns.XMPP_PREFIXES.copy()).encode(charset,'replace')
-        
+
         request.setHeader('content-type', 'text/xml')
         request.setHeader('content-length', len(bxml))
         if self.service.v:
@@ -517,11 +516,11 @@ class Httpb(resource.Resource):
                 log.msg(request.rid)
         request.write(bxml)
         request.finish()
-            
+
     def send_http_error(self, code, request, condition = 'undefined-condition', typ = 'terminate', data = '', charset = 'utf-8', children=None):
         request.setResponseCode(int(code))
         xml_prefixes = ns.XMPP_PREFIXES.copy()
-        
+
         b = domish.Element((NS_BIND, "body"))
         if condition:
             b['condition'] = str(condition)
@@ -540,7 +539,7 @@ class Httpb(resource.Resource):
 
         if self.service.v:
             log.msg('HTTPB Error %d' %(int(code),))
-        
+
         if int(code) != 400 and int(code) != 404 and int(code) != 403:
             if data != '':
                 if condition == 'see-other-uri':
@@ -548,12 +547,12 @@ class Httpb(resource.Resource):
                 else:
                     t = b.addElement('text', content = str(data))
                     t['xmlns'] = 'urn:ietf:params:xml:ns:xmpp-streams'
-                    
+
             bxml = b.toXml(prefixes=xml_prefixes).encode(charset, 'replace')
             if self.service.v:
                 log.msg('HTTPB Return Error: ' + str(code) + ' -> ' + bxml)
             request.setHeader("content-type", "text/xml")
-            request.setHeader("content-length", len(bxml))    
+            request.setHeader("content-length", len(bxml))
             request.write(bxml)
         else:
             request.setHeader("content-length", "0")
@@ -570,9 +569,9 @@ class HttpbService(punjab.Service):
     white_list = []
     black_list = []
 
-    def __init__(self, 
-                 verbose = 0, polling = 15, 
-                 use_raw = False, bindAddress=("0.0.0.0", 0), 
+    def __init__(self,
+                 verbose = 0, polling = 15,
+                 use_raw = False, bindAddress=("0.0.0.0", 0),
                  session_creator = None):
         if session_creator is not None:
             self.make_session = session_creator
@@ -580,7 +579,6 @@ class HttpbService(punjab.Service):
             self.make_session = make_session
         self.v  = verbose
         self.sessions = {}
-        self.counter  = 0
         self.polling = polling
         # self.expired  = {}
         self.use_raw  = use_raw
@@ -603,16 +601,16 @@ class HttpbService(punjab.Service):
                     if time_now - wr.wait_start >= wr.timeout:
                         wr.delayedcall(wr.deferred)
 
-            
+
     def startSession(self, body, xmpp_elements):
         """ Start a punjab jabber session """
-    
+
         # look for rid
         if not body.hasAttribute('rid') or body['rid']=='':
             if self.v:
                 log.msg('start session called but we had a rid')
             return None, defer.fail(error.NotFound)
-                
+
         # look for to
         if not body.hasAttribute('to') or body['to']=='':
             return None, defer.fail(error.BadRequest)
@@ -658,7 +656,7 @@ class HttpbService(punjab.Service):
         # look for wait
         if not body.hasAttribute('wait') or body['wait']=='':
             body['wait'] = 3
-                
+
         # look for lang
         lang = None
         if not body.hasAttribute("xml:lang") or body['xml:lang']=='':
@@ -669,8 +667,7 @@ class HttpbService(punjab.Service):
         if lang:
             body['lang'] = lang
         if not body.hasAttribute('inactivity'):
-            body['inactivity'] = 60 
-        
+            body['inactivity'] = 60
         return self.make_session(self, body.attributes)
 
     def stopService(self):
@@ -689,7 +686,7 @@ class HttpbService(punjab.Service):
 
     def parseBody(self, body, xmpp_elements):
         try:
-            # grab session                    
+            # grab session
             if body.hasAttribute('sid'):
                 sid = str(body['sid'])
             else:
@@ -703,43 +700,30 @@ class HttpbService(punjab.Service):
                 if self.v:
                     log.msg('session does not exist?')
                 return None, defer.fail(error.NotFound)
-            ##  XXX this seems to break xmpp:restart='true'  --vargas
-            ##  (cf. http://www.xmpp.org/extensions/xep-0206.html#preconditions-sasl [Example 10])
-##            if body.hasAttribute('to') and body['to']!='':
-##                return s, defer.fail(error.BadRequest)
-            
-            # check for keys
-            # TODO - clean this up
-            foundNewKey = False
-            
+
+            if bool(s.key) != body.hasAttribute('key'):
+                # This session is keyed, but there's no key in this packet; or there's
+                # a key in this packet, but the session isn't keyed.
+                return s, defer.fail(error.Error('item-not-found'))
+
+            # If this session is keyed, validate the next key.
+            if s.key:
+                key = hashlib.sha1(body['key']).hexdigest()
+                next_key = body['key']
+                if key != s.key:
+                    if self.v:
+                        log.msg('Error in key')
+                    return s, defer.fail(error.Error('item-not-found'))
+                s.key = next_key
+
+            # If there's a newkey in this packet, save it.  Do this after validating the
+            # previous key.
             if body.hasAttribute('newkey'):
-                newkey = body['newkey']
-                s.key = newkey
-                foundNewKey = True
-            try:
-                if body.hasAttribute('key') and not foundNewKey:
-                    if s.key is not None:
-                        nk = sha.new(body['key'])
-                        key = nk.hexdigest()
-                        next_key = body['key']
-                        if key == s.key:
-                            s.key = next_key
-                        else:
-                            if self.v:
-                                log.msg('Error in key')
-                            return s, defer.fail(error.NotFound)                        
-                    else:
-                        log.err()
-                        raise s, defer.fail(error.NotFound)
-                        
-            except:
-                log.msg('HTTPB ERROR: ')
-                log.err()
-                return s, defer.fail(error.NotFound)
-            
-        
+                s.key = body['newkey']
+
+
             # need to check if this is a valid rid (within tolerance)
-            if body.hasAttribute('rid') and body['rid']!='': 
+            if body.hasAttribute('rid') and body['rid']!='':
                 if s.cache_data.has_key(int(body['rid'])):
                     s.touch()
                     # implements issue 32 and returns the data returned on a dropped connection
@@ -752,45 +736,41 @@ class HttpbService(punjab.Service):
                 if self.v:
                     log.msg('There is no rid on this request')
                 return  s, defer.fail(error.NotFound)
-            
+
             return s, self._parse(s, body, xmpp_elements)
-            
-            
+
         except:
             log.err()
             return  s, defer.fail(error.InternalServerError)
 
-            
+
     def onExpire(self, session_id):
         """ preform actions based on when the jabber connection expires """
         if self.v:
             log.msg('expire (%s)' % (str(session_id),))
             log.msg(len(self.sessions.keys()))
-        
+
     def _parse(self, session, body_tag, xmpp_elements):
         # increment the request counter
         session.rid  = session.rid + 1
-        dont_poll = False
-        d = None
-        
+
         if getattr(session, 'stream_error', None) != None:
-            # set up waiting request
-            d = defer.Deferred()            
+            # The server previously sent us a stream:error, and has probably closed
+            # the connection by now.  Forward the error to the client and terminate
+            # the session.
+            d = defer.Deferred()
             d.errback(session.stream_error)
             session.elems = []
             session.terminate()
+            return d
 
-            dont_poll = True
-        else:
-            # send all the elements
-            for el in xmpp_elements:
-                if not isinstance(el, domish.Element):
-                    session.sendRawXml(el)
-                    continue
-            
+        # Send received elements from the client to the server.  Do this even for
+        # type='terminate'.
+        for el in xmpp_elements:
+            if isinstance(el, domish.Element):
                 # something is wrong here, need to figure out what
                 # the xmlns will be lost if this is not done
-                # punjab.uriCheck(el,NS_BIND)              
+                # punjab.uriCheck(el,NS_BIND)
                 # if el.uri and el.uri != NS_BIND:
                 #    el['xmlns'] = el.uri
                 # TODO - get rid of this when we stop supporting old versions
@@ -799,18 +779,16 @@ class HttpbService(punjab.Service):
                     el.uri = None
                 if el.defaultUri == NS_BIND:
                     el.defaultUri = None
-                    
-                session.sendRawXml(el)
+
+            session.sendRawXml(el)
 
         if body_tag.hasAttribute('type') and \
            body_tag['type'] == 'terminate':
-            d = session.terminate()
-        elif not dont_poll:
-            # normal request
-            d = session.poll(d, rid = int(body_tag['rid']))
-            
-        return d
-        
+            return session.terminate()
+
+        # normal request
+        return session.poll(None, rid = int(body_tag['rid']))
+
     def _returnIq(self, cur_session, d, iq):
         """
         A callback from auth iqs
@@ -821,15 +799,14 @@ class HttpbService(punjab.Service):
         """
         A callback from auth iqs
         """
-        
+
         # session.elems.append(iq)
         return cur_session.poll(d)
-        
-        
+
     def inSession(self, body):
         """ """
         if body.hasAttribute('sid'):
-            if self.sessions.has_key(body['sid']):        
+            if self.sessions.has_key(body['sid']):
                 return True
         return False
 
@@ -839,18 +816,18 @@ class HttpbService(punjab.Service):
         """
         for i, obj in enumerate(session.msgs):
             m = session.msgs.pop(0)
-            b.addChild(m)            
+            b.addChild(m)
         for i, obj in enumerate(session.prs):
             p = session.prs.pop(0)
-            b.addChild(p)            
+            b.addChild(p)
         for i, obj in enumerate(session.iqs):
             iq = session.iqs.pop(0)
             b.addChild(iq)
-        
+
         return b
 
     def endSession(self, cur_session):
         """ end a punjab jabber session """
         d = cur_session.terminate()
         return d
-                
+
index 48a6970..d1708db 100644 (file)
@@ -1,5 +1,5 @@
 from twisted.internet import defer, protocol, reactor, stdio
-from twisted.python import log, reflect
+from twisted.python import log, reflect, failure
 try:
     from twisted.words.xish import domish, utility
 except:
@@ -10,7 +10,7 @@ from twisted.words.protocols.jabber import xmlstream, client, jid
 
 from twisted.protocols import basic
 import urlparse
-import random, binascii, base64, md5, sha, time, os, random
+import random, binascii, base64, hashlib, time, os, random
 
 import os,sys
 
@@ -52,6 +52,19 @@ class NotImplemented(Error):
     pass
 
 
+
+# Exceptions raised by the client.
+class HTTPBException(Exception): pass
+class HTTPBNetworkTerminated(HTTPBException):
+    def __init__(self, body_tag, elements):
+        self.body_tag = body_tag
+        self.elements = elements
+
+    def __str__(self):
+        return self.body_tag.toXml()
+
+
+
 class XMPPAuthenticator(client.XMPPAuthenticator):
     """
     Authenticate against an xmpp server using BOSH
@@ -157,10 +170,11 @@ class QueryFactory(protocol.ClientFactory):
             raise
         else:
             if body_tag.hasAttribute('type') and body_tag['type'] == 'terminate':
+                error = failure.Failure(HTTPBNetworkTerminated(body_tag, elements))
                 if self.deferred.called:
-                    return defer.fail((body_tag,elements))
+                    return defer.fail(error)
                 else:            
-                    self.deferred.errback((body_tag,elements))
+                    self.deferred.errback(error)
                 return
             if self.deferred.called:
                 return defer.succeed((body_tag,elements))
@@ -190,40 +204,38 @@ class QueryFactory(protocol.ClientFactory):
             
 
 
-import random, sha, md5
 
 class Keys:
-    """ A class to generate keys for http binding """
+    """Generate keys according to XEP-0124 #15 "Protecting Insecure Sessions"."""
     def __init__(self):
-        self.set_keys()
-        
-        
-    def set_keys(self):
-        seed = random.randint(30,1000000)
-        self.num_keys = random.randint(55,255)
         self.k = []
-        self.k.append(seed)
-        for i in range(self.num_keys-1):
-            x = i + 1
-            self.k.append(sha.new(str(self.k[x-1])).hexdigest())
+        
+    def _set_keys(self):
+        seed = os.urandom(1024)
+        num_keys = random.randint(55,255)
+        self.k = [hashlib.sha1(seed).hexdigest()]
+        for i in xrange(num_keys-1):
+            self.k.append(hashlib.sha1(self.k[-1]).hexdigest())
 
-        self.key_index = self.num_keys - 1
-    
     def getKey(self):
-        self.key_index = self.key_index - 1
-        return self.k.pop(self.key_index)
+        """
+        Return (key, newkey), where key is the next key to use and newkey is the next
+        newkey value to use.  If key or newkey are None, the next request doesn't require
+        that value.
+        """
+        if not self.k:
+            # This is the first call, so generate keys and only return new_key.
+            self._set_keys()
+            return None, self.k.pop()
 
-    def firstKey(self):
-        if self.key_index == self.num_keys - 1:
-            return 1
-        else:
-            return 0
+        key = self.k.pop()
 
-    def lastKey(self):
-        if self.key_index == 0:
-            return 1
-        else:
-            return 0
+        if not self.k:
+            # We're out of keys.  Regenerate keys and re-key.
+            self._set_keys()
+            return key, self.k.pop()
+
+        return key, None
 
 
 class Proxy:
@@ -365,14 +377,12 @@ class HTTPBindingStream(xmlstream.XmlStream):
             
 
     def key(self,b):
-        if self.keys.lastKey():
-            self.keys.setKeys()
-        
-        if self.keys.firstKey():
-            b['newkey'] = self.keys.getKey()
-        else:
-            b['key'] = self.keys.getKey()
-        return b
+        key, newkey = self.keys.getKey()
+
+        if key:
+            b['key'] = key
+        if newkey:
+            b['newkey'] = newkey
 
     def _cbSend(self, result):
         body, elements = result
index 0147529..2c68a1b 100644 (file)
@@ -1,4 +1,4 @@
-# punjab's jabber client 
+# punjab's jabber client
 from twisted.internet import reactor, error
 from twisted.words.protocols.jabber import client, jid
 from twisted.python import log
@@ -41,7 +41,7 @@ class JabberClientFactory(xmlstream.XmlStreamFactory):
         """
         p = self.authenticator = PunjabAuthenticator(host)
         xmlstream.XmlStreamFactory.__init__(self, p)
-        
+
         self.pending = {}
         self.maxRetries = 2
         self.host = host
@@ -61,7 +61,6 @@ class JabberClientFactory(xmlstream.XmlStreamFactory):
             if self.maxRetries and (self.retries > self.maxRetries):
                 if d:
                     d.errback(reason)
-                
 
 
     def rawDataIn(self, buf):
@@ -71,7 +70,6 @@ class JabberClientFactory(xmlstream.XmlStreamFactory):
     def rawDataOut(self, buf):
         log.msg("SEND: %s" % unicode(buf, 'utf-8').encode('ascii', 'replace'))
 
-        
 
 class PunjabAuthenticator(xmlstream.ConnectAuthenticator):
     namespace = "jabber:client"
@@ -88,9 +86,8 @@ class PunjabAuthenticator(xmlstream.ConnectAuthenticator):
             self.xmlstream.otherEntity = jid.internJID(self.otherHost)
         self.xmlstream.prefixes = deepcopy(XMPP_PREFIXES)
         self.xmlstream.sendHeader()
-                
+
     def streamStarted(self, rootelem = None):
-        
         if hasNewTwisted: # This is here for backwards compatibility
             xmlstream.ConnectAuthenticator.streamStarted(self, rootelem)
         else:
@@ -98,7 +95,7 @@ class PunjabAuthenticator(xmlstream.ConnectAuthenticator):
         if rootelem is None:
             self.xversion = 3
             return
-        
+
         self.xversion = 0
         if rootelem.hasAttribute('version'):
             self.version = rootelem['version']
@@ -106,7 +103,6 @@ class PunjabAuthenticator(xmlstream.ConnectAuthenticator):
             self.version = 0.0
 
     def associateWithStream(self, xs):
-        
         xmlstream.ConnectAuthenticator.associateWithStream(self, xs)
 
         inits = [ (xmlstream.TLSInitiatingInitializer, False),
@@ -118,7 +114,6 @@ class PunjabAuthenticator(xmlstream.ConnectAuthenticator):
             init.required = required
             xs.initializers.append(init)
 
-        
     def _reset(self):
         # need this to be in xmlstream
         self.xmlstream.stream = domish.elementStream()
@@ -127,13 +122,12 @@ class PunjabAuthenticator(xmlstream.ConnectAuthenticator):
         self.xmlstream.stream.DocumentEndEvent = self.xmlstream.onDocumentEnd
         self.xmlstream.prefixes = deepcopy(XMPP_PREFIXES)
         # Generate stream header
-        
+
         if self.version != 0.0:
             sh = "<stream:stream xmlns='%s' xmlns:stream='http://etherx.jabber.org/streams' version='%s' to='%s'>" % \
                  (self.namespace,self.version, self.streamHost.encode('utf-8'))
 
             self.xmlstream.send(str(sh))
-                                                                                                                
 
     def sendAuth(self, jid, passwd, callback, errback = None):
         self.jid    = jid
@@ -142,29 +136,28 @@ class PunjabAuthenticator(xmlstream.ConnectAuthenticator):
             self.xmlstream.addObserver(INVALID_USER_EVENT,errback)
             self.xmlstream.addObserver(AUTH_FAILED_EVENT,errback)
         if self.version != '1.0':
-            
             iq = client.IQ(self.xmlstream, "get")
             iq.addElement(("jabber:iq:auth", "query"))
             iq.query.addElement("username", content = jid.user)
             iq.addCallback(callback)
             iq.send()
 
-    
-    def authQueryResultEvent(self, iq, callback):        
+
+    def authQueryResultEvent(self, iq, callback):
         if iq["type"] == "result":
             # Construct auth request
             iq = client.IQ(self.xmlstream, "set")
             iq.addElement(("jabber:iq:auth", "query"))
             iq.query.addElement("username", content = self.jid.user)
             iq.query.addElement("resource", content = self.jid.resource)
-            
+
             # Prefer digest over plaintext
             if client.DigestAuthQry.matches(iq):
                 digest = xmlstream.hashPassword(self.xmlstream.sid, self.passwd)
                 iq.query.addElement("digest", content = digest)
             else:
                 iq.query.addElement("password", content = self.passwd)
-                
+
             iq.addCallback(callback)
             iq.send()
         else:
diff --git a/punjab/patches.py b/punjab/patches.py
new file mode 100644 (file)
index 0000000..94520d3
--- /dev/null
@@ -0,0 +1,24 @@
+# XXX: All monkey patches should be sent upstream and eventually removed.
+
+import functools
+
+def patch(cls, attr):
+    """Patch the function named attr in the object cls with the decorated function."""
+    orig_func = getattr(cls, attr)
+    @functools.wraps(orig_func)
+    def decorator(func):
+        def wrapped_func(*args, **kwargs):
+            return func(orig_func, *args, **kwargs)
+        setattr(cls, attr, wrapped_func)
+        return orig_func
+    return decorator
+
+# Modify jabber.error.exceptionFromStreamError to include the XML element in
+# the exception.
+from twisted.words.protocols.jabber import error as jabber_error
+@patch(jabber_error, "exceptionFromStreamError")
+def exceptionFromStreamError(orig, element):
+    exception = orig(element)
+    exception.element = element
+    return exception
+
index de49f20..c080dd3 100644 (file)
@@ -9,13 +9,13 @@ from twisted.names.srvconnect import SRVConnector
 
 try:
     from twisted.words.xish import domish, xmlstream
+    from twisted.words.protocols import jabber as jabber_protocol
 except ImportError:
     from twisted.xish import domish, xmlstream
 
 
 import traceback
-import random
-import md5
+import os
 from punjab import jabber
 from punjab.xmpp import ns
 
@@ -25,6 +25,10 @@ import error
 try:
     from twisted.internet import ssl
 except ImportError:
+    ssl = None
+if ssl and not ssl.supported:
+    ssl = None
+if not ssl:
     log.msg("SSL ERROR: You do not have ssl support this may cause problems with tls client connections.")
 
 
@@ -44,49 +48,25 @@ class XMPPClientConnector(SRVConnector):
         """
         host, port = SRVConnector.pickServer(self)
 
-        if not self.servers and not self.orderedServers:
-            # no SRV record, fall back..
-            port = 5222
-        if port == 5223 and xmlstream.ssl:
-            context = xmlstream.ssl.ClientContextFactory()
-            context.method = xmlstream.ssl.SSL.SSLv23_METHOD
-            
-            self.connectFunc = 'connectSSL'
-            self.connectFuncArgs = (context)
+        if port == 5223 and ssl:
+            context = ssl.ClientContextFactory()
+            context.method = ssl.SSL.SSLv23_METHOD
+
+            self.connectFuncName = 'connectSSL'
+            self.connectFuncArgs = (context,)
         return host, port
 
 def make_session(pint, attrs, session_type='BOSH'):
     """
     pint  - punjab session interface class
     attrs - attributes sent from the body tag
-    """    
-
-    # this may need some work, idea, code taken from twisted.web.server
-    pint.counter = pint.counter + 1
-    sid  = md5.new("%s_%s_%s" % (str(time.time()), str(random.random()) , str(pint.counter))).hexdigest()
-
+    """
 
-    s    = Session(pint, sid, attrs)
-    
-    s.addBootstrap(xmlstream.STREAM_START_EVENT, s.streamStart)
-    s.addBootstrap(xmlstream.STREAM_CONNECTED_EVENT, s.connectEvent)
-    s.addBootstrap(xmlstream.STREAM_ERROR_EVENT, s.streamError)
-    s.addBootstrap(xmlstream.STREAM_END_EVENT, s.connectError)    
-    
-    s.inactivity = int(attrs.get('inactivity', 900)) # default inactivity 15 mins
-    
-    s.secure = 0
-    s.use_raw = getattr(pint, 'use_raw', False) # use raw buffers
-    
-    if attrs.has_key('secure') and attrs['secure'] == 'true':
-        s.secure = 1
-        s.authenticator.useTls = 1
-    else:
-        s.authenticator.useTls = 0
+    s    = Session(pint, attrs)
 
     if pint.v:
         log.msg('================================== %s connect to %s:%s ==================================' % (str(time.time()),s.hostname,s.port))
-        
+
     connect_srv = True
     if attrs.has_key('route'):
         connect_srv = False
@@ -100,10 +80,10 @@ def make_session(pint, attrs, session_type='BOSH'):
     # timeout
     reactor.callLater(s.inactivity, s.checkExpired)
 
-    pint.sessions[sid] = s
-    
+    pint.sessions[s.sid] = s
+
     return s, s.waiting_requests[0].deferred
-    
+
 
 class WaitingRequest(object):
     """A helper object for managing waiting requests."""
@@ -128,7 +108,7 @@ class WaitingRequest(object):
 
 class Session(jabber.JabberClientFactory, server.Session):
     """ Jabber Client Session class for client XMPP connections. """
-    def __init__(self, pint, sid, attrs):
+    def __init__(self, pint, attrs):
         """
         Initialize the session
         """
@@ -150,11 +130,12 @@ class Session(jabber.JabberClientFactory, server.Session):
             else:
                 self.port = 5222
         
+        self.sid = "".join("%02x" % ord(i) for i in os.urandom(20))
+
         jabber.JabberClientFactory.__init__(self, self.to, pint.v)
-        server.Session.__init__(self, pint, sid)
+        server.Session.__init__(self, pint, self.sid)
         self.pint  = pint
 
-        self.sid   = sid
         self.attrs = attrs
         self.s     = None
 
@@ -167,7 +148,6 @@ class Session(jabber.JabberClientFactory, server.Session):
         self.raw_buffer = u""
         self.xmpp_node  = ''       
         self.success    = 0        
-        self.secure     = 0
         self.mechanisms = []
         self.xmlstream  = None
         self.features   = None
@@ -179,13 +159,12 @@ class Session(jabber.JabberClientFactory, server.Session):
 
         self.version = attrs.get('version', 0.0)
                 
-        if attrs.has_key('newkey'):
-            newkey   = attrs['newkey']
-            self.key = newkey
+        self.key = attrs.get('newkey')
         
         self.wait  = int(attrs.get('wait', 0))            
 
         self.hold  = int(attrs.get('hold', 0))
+        self.inactivity = int(attrs.get('inactivity', 900)) # default inactivity 15 mins
 
         if attrs.has_key('window'):
             self.window  = int(attrs['window'])
@@ -205,6 +184,11 @@ class Session(jabber.JabberClientFactory, server.Session):
         else:
             self.hostname = self.to
             
+        self.use_raw = getattr(pint, 'use_raw', False) # use raw buffers
+
+        self.secure = attrs.has_key('secure') and attrs['secure'] == 'true'
+        self.authenticator.useTls = self.secure
+
         if attrs.has_key('route'):
             if attrs['route'].startswith("xmpp:"):
                 self.route = attrs['route'][5:]
@@ -229,6 +213,11 @@ class Session(jabber.JabberClientFactory, server.Session):
         if pint.v:
             log.msg('Session Created : %s %s' % (str(self.sid),str(time.time()), ))
         
+        self.addBootstrap(xmlstream.STREAM_START_EVENT, self.streamStart)
+        self.addBootstrap(xmlstream.STREAM_CONNECTED_EVENT, self.connectEvent)
+        self.addBootstrap(xmlstream.STREAM_ERROR_EVENT, self.streamError)
+        self.addBootstrap(xmlstream.STREAM_END_EVENT, self.connectError)
+
         # create the first waiting request
         d = defer.Deferred()
         timeout = 30
@@ -308,10 +297,7 @@ class Session(jabber.JabberClientFactory, server.Session):
         if 'onExpire' in dir(self.pint):
             self.pint.onExpire(self.sid)
         if self.verbose and not getattr(self, 'terminated', False):
-            log.msg(self.sid)
-            log.msg(self.rid)
-            log.msg(self.waiting_requests)
-            log.msg('SESSION -> We have expired')
+            log.msg('SESSION -> We have expired', self.sid, self.rid, self.waiting_requests)
         self.disconnect()
     
     def terminate(self):
@@ -449,7 +435,7 @@ class Session(jabber.JabberClientFactory, server.Session):
                 
         # There is a tls initializer added by us, if it is available we need to try it
         if len(initializers)>0 and starttls:
-            self.secure = 1
+            self.secure = True
 
         if self.authid is None:
             self.authid = self.xmlstream.sid
@@ -517,13 +503,15 @@ class Session(jabber.JabberClientFactory, server.Session):
     def streamError(self, streamerror):
         """called when we get a stream:error stanza"""
         
-        try: # a workaround for a bug in twisted.words.protocols.jabber.error
-            err_elem = streamerror.value.getElement()
-            err_elem.toXml()
-        except: # no matter what the exception we just return None
-            err_elem = None
+        if isinstance(streamerror.value, jabber_protocol.error.StreamError):
+            # This is an actual stream:error.  Create a remote-stream-error to encapsulate it.
+            err_elem = getattr(streamerror.value, "element")
+            e = self.buildRemoteError(err_elem)
+        else:
+            # This is another error, such as an XML parsing error.  This isn't a stream:error,
+            # so expose it as remote-connection-failed.
+            e = error.Error('remote-connection-failed')
 
-        e = self.buildRemoteError(err_elem)
         do_expire = True
         
         if len(self.waiting_requests) > 0:
@@ -546,9 +534,9 @@ class Session(jabber.JabberClientFactory, server.Session):
     def connectError(self, xs):
         """called when we get disconnected"""
         
-        # FIXME: we should really only send the error event back if
-        # attempts to reconnect fail.  There's no reason temporary
-        # connection failures should be exposed upstream
+        # If the connection was established and lost, then we need to report the error
+        # back to the client, since he needs to reauthenticate.  FIXME: If the connection was
+        # lost before anything happened, we could silently retry instead.
         if self.verbose:
             log.msg('connect ERROR')
             try:
@@ -557,17 +545,28 @@ class Session(jabber.JabberClientFactory, server.Session):
             except:
                 pass
             
+
+        self.stopTrying()
+
+        e = error.Error('remote-connection-failed')
+
+        do_expire = True
+
         if self.waiting_requests:
-                        
-            if len(self.waiting_requests) > 0:
-                wr = self.waiting_requests.pop(0)
-                wr.doErrback(error.Error('remote-connection-failed'))
+            wr = self.waiting_requests.pop(0)
+            wr.doErrback(e)
+        else: # need to wait for a new request and then expire
+            do_expire = False
 
         if self.pint and self.pint.sessions.has_key(self.sid):
-            try:
-                self.expire()
-            except:
-                self.onExpire()
+            if do_expire:
+                try:
+                    self.expire()
+                except:
+                    self.onExpire()
+            else:
+                s = self.pint.sessions.get(self.sid)
+                s.stream_error = e
 
 
     def sendRawXml(self, obj):
index 393ba49..b9063d1 100644 (file)
@@ -179,10 +179,20 @@ class XMPPServerProtocol(xmlstream.XmlStream):
         self.send("""<challenge xmlns='urn:ietf:params:xml:ns:xmpp-sasl'>cmVhbG09ImNoZXNzcGFyay5jb20iLG5vbmNlPSJ0YUhIM0FHQkpQSE40eXNvNEt5cFlBPT0iLHFvcD0iYXV0aCxhdXRoLWludCIsY2hhcnNldD11dGYtOCxhbGdvcml0aG09bWQ1LXNlc3M=</challenge>""")
 
 
+    def triggerInvalidXML(self):
+        """Send invalid XML, to trigger a parse error."""
+        self.send("""<parse error=>""")
+        self.streamEnded(None)
+
     def triggerStreamError(self):
         """ send a stream error
         """
-        self.send("""<stream:error xmlns:stream='http://etherx.jabber.org/streams'><policy-violation xmlns='urn:ietf:params:xml:ns:xmpp-streams'/></stream:error>
+        self.send("""
+        <stream:error attrib="1" xmlns:stream='http://etherx.jabber.org/streams'>
+            <policy-violation xmlns='urn:ietf:params:xml:ns:xmpp-streams'/>
+            <text xmlns='urn:ietf:params:xml:ns:xmpp-streams' xml:lang='langcode'>Error text</text>
+            <arbitrary-extension val='2'/>
+        </stream:error>
 """)
         self.streamEnded(None)
 
index 0eb6b1e..607f68c 100644 (file)
@@ -14,18 +14,18 @@ from punjab.xmpp import server as xmppserver
 from punjab import httpb_client
 
 class DummyTransport:
-    
+
     def __init__(self):
         self.data = []
-              
+
     def write(self, bytes):
         self.data.append(bytes)
-       
+
     def loseConnection(self, *args, **kwargs):
         self.data = []
 
 class TestCase(unittest.TestCase):
-    """Basic test class for Punjab 
+    """Basic test class for Punjab
     """
 
     def setUp(self):
@@ -38,21 +38,22 @@ class TestCase(unittest.TestCase):
         self.root.putChild('xmpp-bosh', self.b)
 
         self.site  = server.Site(self.root)
-        
+
         self.p =  reactor.listenTCP(0, self.site, interface="127.0.0.1")
         self.port = self.p.getHost().port
 
         # set up proxy
-        
+
         self.proxy = httpb_client.Proxy(self.getURL())
         self.sid   = None
         self.keys  = httpb_client.Keys()
 
         # set up dummy xmpp server
-        
+
         self.server_service = xmppserver.XMPPServerService()
         self.server_factory = xmppserver.IXMPPServerFactory(self.server_service)
-        self.server = reactor.listenTCP(5222, self.server_factory, interface="127.0.0.1")
+        self.server = reactor.listenTCP(0, self.server_factory, interface="127.0.0.1")
+        self.server_port = self.server.socket.getsockname()[1]
 
         # Hook the server's buildProtocol to make the protocol instance
         # accessible to tests.
@@ -72,20 +73,20 @@ class TestCase(unittest.TestCase):
 
 
     def key(self,b):
-        if self.keys.lastKey():
-            self.keys.setKeys()
-        
-        if self.keys.firstKey():
-            b['newkey'] = self.keys.getKey()
-        else:
-            b['key'] = self.keys.getKey()
-        return b 
+        key, newkey = self.keys.getKey()
+
+        if key:
+            b['key'] = key
+        if newkey:
+            b['newkey'] = newkey
+
+        return b
 
     def resend(self, ext = None):
         self.rid = self.rid - 1
         return self.send(ext)
 
-    def send(self, ext = None, sid = None, rid = None):
+    def get_body_node(self, ext=None, sid=None, rid=None, useKey=False, connect=False, **kwargs):
         self.rid = self.rid + 1
         if sid is None:
             sid = self.sid
@@ -93,10 +94,22 @@ class TestCase(unittest.TestCase):
             rid = self.rid
         b = domish.Element(("http://jabber.org/protocol/httpbind","body"))
         b['content']  = 'text/xml; charset=utf-8'
-
-        b['rid']      = str(rid)
-        b['sid']      = str(sid)
+        b['hold'] = '0'
+        b['wait'] = '60'
+        b['ack'] = '1'
         b['xml:lang'] = 'en'
+        b['rid'] = str(rid)
+
+        if sid:
+            b['sid'] = str(sid)
+
+        if connect:
+            b['to'] = 'localhost'
+            b['route'] = 'xmpp:127.0.0.1:%i' % self.server_port
+            b['ver'] = '1.6'
+
+        if useKey:
+            self.key(b)
 
         if ext is not None:
             if isinstance(ext, domish.Element):
@@ -104,10 +117,26 @@ class TestCase(unittest.TestCase):
             else:
                 b.addRawXml(ext)
 
-        b = self.key(b)
+        for key, value in kwargs.iteritems():
+            b[key] = value
+        return b
+
+    def send(self, ext = None, sid = None, rid = None):
+        b = self.get_body_node(ext, sid, rid)
         d = self.proxy.send(b)
         return d
 
+    def _storeSID(self, res):
+        self.sid = res[0]['sid']
+        return res
+
+    def connect(self, b):
+        d = self.proxy.connect(b)
+        # If we don't already have a SID, store the one we get back.
+        if not self.sid:
+            d.addCallback(self._storeSID)
+        return d
+
         
 
     def _error(self, e):
@@ -151,10 +180,11 @@ class TestCase(unittest.TestCase):
                 self.b.service.endSession(sess)
         if hasattr(self.proxy.factory,'client'):
             self.proxy.factory.client.transport.stopConnecting()
+        self.server_factory.protocol.delay_features = 0
         
 
         d = defer.maybeDeferred(self.server.stopListening)
         d.addCallback(cbStopListening)
 
         return d
-        
+
index b1e2f15..663ffa6 100644 (file)
@@ -1,6 +1,6 @@
 
 import os
-import sys, sha, random
+import sys, random
 from twisted.trial import unittest
 import time
 from twisted.web import server, resource, static, http, client
index 7fca753..3d5c2a0 100644 (file)
@@ -1,18 +1,8 @@
-
-import os
-import sys, sha, random
-from twisted.trial import unittest
-import time
-from twisted.web import server, resource, static, http, client
-from twisted.words.protocols.jabber import jid
-from twisted.internet import defer, protocol, reactor
-from twisted.application import internet, service
-from twisted.words.xish import domish, xpath
+from twisted.internet import defer, reactor, task
+from twisted.words.xish import xpath
 
 from twisted.python import log
 
-from punjab.httpb import HttpbService
-from punjab.xmpp import server as xmppserver
 from punjab import httpb_client
 
 import test_basic
@@ -28,27 +18,27 @@ class XEP0124TestCase(test_basic.TestCase):
         """
         Test Section 7.1 of BOSH xep : http://www.xmpp.org/extensions/xep-0124.html#session
         """
-        
+
         def _testSessionCreate(res):
-            self.failUnless(res[0].name=='body', 'Wrong element')            
+            self.failUnless(res[0].name=='body', 'Wrong element')
             self.failUnless(res[0].hasAttribute('sid'), 'Not session id')
-            
+
         def _error(e):
-            # This fails on DNS 
+            # This fails on DNS
             log.err(e)
-            
 
         BOSH_XML = """<body content='text/xml; charset=utf-8'
       hold='1'
       rid='1573741820'
       to='localhost'
+      route='xmpp:127.0.0.1:%(server_port)i'
       secure='true'
       ver='1.6'
       wait='60'
       ack='1'
       xml:lang='en'
       xmlns='http://jabber.org/protocol/httpbind'/>
- """
+ """% { "server_port": self.server_port }
 
         d = self.proxy.connect(BOSH_XML).addCallback(_testSessionCreate)
         d.addErrback(_error)
@@ -59,28 +49,29 @@ class XEP0124TestCase(test_basic.TestCase):
         """
         Basic tests for whitelisting domains.
         """
-        
+
         def _testSessionCreate(res):
-            self.failUnless(res[0].name=='body', 'Wrong element')            
+            self.failUnless(res[0].name=='body', 'Wrong element')
             self.failUnless(res[0].hasAttribute('sid'), 'Not session id')
-            
+
         def _error(e):
-            # This fails on DNS 
+            # This fails on DNS
             log.err(e)
-            
+
         self.hbs.white_list = ['.localhost']
         BOSH_XML = """<body content='text/xml; charset=utf-8'
       hold='1'
       rid='1573741820'
       to='localhost'
+      route='xmpp:127.0.0.1:%(server_port)i'
       secure='true'
       ver='1.6'
       wait='60'
       ack='1'
       xml:lang='en'
       xmlns='http://jabber.org/protocol/httpbind'/>
- """
-        
+ """% { "server_port": self.server_port }
+
         d = self.proxy.connect(BOSH_XML).addCallback(_testSessionCreate)
         d.addErrback(_error)
         return d
@@ -89,26 +80,33 @@ class XEP0124TestCase(test_basic.TestCase):
         """
         Basic tests for whitelisting domains.
         """
-        
+
         def _testSessionCreate(res):
             self.fail("Session should not be created")
-            
+
         def _error(e):
-            return True
-            
+            # This is the error we expect.
+            if isinstance(e.value, ValueError) and e.value.args == ('400', 'Bad Request'):
+                return True
+
+            # Any other error, including the error raised from _testSessionCreate, should
+            # be propagated up to the test runner.
+            return e
+
         self.hbs.white_list = ['test']
         BOSH_XML = """<body content='text/xml; charset=utf-8'
       hold='1'
       rid='1573741820'
       to='localhost'
+      route='xmpp:127.0.0.1:%(server_port)i'
       secure='true'
       ver='1.6'
       wait='60'
       ack='1'
       xml:lang='en'
       xmlns='http://jabber.org/protocol/httpbind'/>
- """
-        
+ """% { "server_port": self.server_port }
+
         d = self.proxy.connect(BOSH_XML).addCallback(_testSessionCreate)
         d.addErrback(_error)
         return d
@@ -116,86 +114,78 @@ class XEP0124TestCase(test_basic.TestCase):
     def testSessionTimeout(self):
         """Test if we timeout correctly
         """
-        d = defer.Deferred()
 
         def testTimeout(res):
-            passed = True
-            
-            if res.value[0]!='404':
-                passed = False
-                d.errback((Exception, 'Wrong Value %s '% (str(res.value),)))
-            if passed:
-                d.callback(True)
-            else:
-                log.err(res)
+            self.failUnlessEqual(res.value[0], '404')
 
         def testCBTimeout(res):
-            # check for terminate if we expire 
+            # check for terminate if we expire
             terminate = res[0].getAttribute('type',False)
-            
-            if str(terminate) != 'terminate':
-                d.errback((Exception, 'Was not terminate'))
-                return
-            d.callback(True)
+            self.failUnlessEqual(terminate, 'terminate')
 
         def sendTest():
             sd = self.send()
             sd.addCallback(testCBTimeout)
             sd.addErrback(testTimeout)
-            
+            return sd
 
         def testResend(res):
             self.failUnless(res[0].name=='body', 'Wrong element')
             s = self.b.service.sessions[self.sid]
-            self.failUnless(s.inactivity==10,'Wrong inactivity value')
-            self.failUnless(s.wait==10, 'Wrong wait value')
-            reactor.callLater(s.wait+s.inactivity+1, sendTest)
-            
+            self.failUnless(s.inactivity==2,'Wrong inactivity value')
+            self.failUnless(s.wait==2, 'Wrong wait value')
+            return task.deferLater(reactor, s.wait+s.inactivity+1, sendTest)
 
         def testSessionCreate(res):
-            self.failUnless(res[0].name=='body', 'Wrong element')            
+            self.failUnless(res[0].name=='body', 'Wrong element')
             self.failUnless(res[0].hasAttribute('sid'),'Not session id')
             self.sid = res[0]['sid']
 
-            # send and wait 
+            # send and wait
             sd = self.send()
-            
             sd.addCallback(testResend)
-            
+            return sd
+
 
 
         BOSH_XML = """<body content='text/xml; charset=utf-8'
       hold='1'
-      rid='%d'
+      rid='%(rid)i'
       to='localhost'
-      route='xmpp:127.0.0.1:5222'
+      route='xmpp:127.0.0.1:%(server_port)i'
       ver='1.6'
-      wait='10'
+      wait='2'
       ack='1'
-      inactivity='10'
+      inactivity='2'
       xml:lang='en'
       xmlns='http://jabber.org/protocol/httpbind'/>
- """% (self.rid,)
+ """% { "rid": self.rid, "server_port": self.server_port }
 
-        self.proxy.connect(BOSH_XML).addCallback(testSessionCreate)
-        d.addErrback(self.fail)
-        return d
+        return self.proxy.connect(BOSH_XML).addCallbacks(testSessionCreate)
 
     def testStreamError(self):
         """
         This is to test if we get stream errors when there are no waiting requests.
         """
-        
+
         def _testStreamError(res):
-            self.failUnless(res.value[0].hasAttribute('condition'), 'No attribute condition')
-            self.failUnless(res.value[0]['condition'] == 'remote-stream-error', 'Condition should be remote stream error')
-            self.failUnless(res.value[1][0].children[0].name == 'policy-violation', 'Error should be policy violation')
+            if not isinstance(res.value, httpb_client.HTTPBNetworkTerminated):
+                return res
+
+            self.failUnless(res.value.body_tag.hasAttribute('condition'), 'No attribute condition')
+            self.failUnlessEqual(res.value.body_tag['condition'], 'remote-connection-failed')
+
+            # The XML should exactly match the error XML sent by triggerStreamError().
+            self.failUnless(xpath.XPathQuery("/error[@attrib='1']").matches(res.value.elements[0]))
+            self.failUnless(xpath.XPathQuery("/error/policy-violation").matches(res.value.elements[0]))
+            self.failUnless(xpath.XPathQuery("/error/arbitrary-extension").matches(res.value.elements[0]))
+            self.failUnless(xpath.XPathQuery("/error/text[text() = 'Error text']").matches(res.value.elements[0]))
 
 
 
         def _failStreamError(res):
             self.fail('A stream error needs to be returned')
-            
+
         def _testSessionCreate(res):
             self.sid = res[0]['sid']
             # this xml is valid, just for testing
@@ -206,30 +196,169 @@ class XEP0124TestCase(test_basic.TestCase):
             self.server_protocol.triggerStreamError()
 
             return d
-            
+
         BOSH_XML = """<body content='text/xml; charset=utf-8'
       hold='1'
-      rid='%d'
+      rid='%(rid)i'
       to='localhost'
-      route='xmpp:127.0.0.1:5222'
+      route='xmpp:127.0.0.1:%(server_port)i'
       ver='1.6'
       wait='60'
       ack='1'
       xml:lang='en'
       xmlns='http://jabber.org/protocol/httpbind'/>
- """ % (self.rid,)
+ """% { "rid": self.rid, "server_port": self.server_port }
 
         d = self.proxy.connect(BOSH_XML).addCallback(_testSessionCreate)
 
         return d
 
 
+    @defer.inlineCallbacks
+    def testStreamFlushOnError(self):
+        """
+        Test that messages included in a <body type='terminate'> message from the
+        client are sent to the server before terminating.
+        """
+        yield self.connect(self.get_body_node(connect=True))
+
+        # Set got_testing_node to true when the XMPP server receives the <testing/> we
+        # send below.
+        got_testing_node = [False] # work around Python's 2.6 lack of nonlocal
+        wait = defer.Deferred()
+        def received_testing(a):
+            got_testing_node[0] = True
+            wait.callback(True)
+        self.server_protocol.addObserver("/testing", received_testing)
+
+        # Ensure that we always remove the received_testing listener.
+        try:
+            # Send <body type='terminate'><testing/></body>.  This should result in a
+            # HTTPBNetworkTerminated exception.
+            try:
+                yield self.proxy.send(self.get_body_node(ext='<testing/>', type='terminate'))
+            except httpb_client.HTTPBNetworkTerminated as e:
+                self.failUnlessEqual(e.body_tag.getAttribute('condition', None), None)
+
+            # Wait until <testing/> is actually received by the XMPP server.  The previous
+            # request completing only means that the proxy has received the stanza, not that
+            # it's been delivered to the XMPP server.
+            yield wait
+
+        finally:
+            self.server_protocol.removeObserver("/testing", received_testing)
+
+        # This should always be true, or we'd never have woken up from wait.
+        self.failUnless(got_testing_node[0])
+
+    @defer.inlineCallbacks
+    def testTerminateRace(self):
+        """Test that buffered messages are flushed when the connection is terminated."""
+        yield self.connect(self.get_body_node(connect=True))
+
+        def log_observer(event):
+            self.failIf(event['isError'], event)
+
+        log.addObserver(log_observer)
+
+        # Simultaneously cause a stream error (server->client closed) and send a terminate
+        # from the client to the server.  Both sides are closing the connection at once.
+        # Make sure the connection closes cleanly without logging any errors ("Unhandled
+        # Error"), and the client receives a terminate in response.
+        try:
+            self.server_protocol.triggerStreamError()
+            yield self.proxy.send(self.get_body_node(type='terminate'))
+        except httpb_client.HTTPBNetworkTerminated as e:
+            self.failUnlessEqual(e.body_tag.getAttribute('condition', None), 'remote-stream-error')
+        finally:
+            log.removeObserver(log_observer)
+
+    @defer.inlineCallbacks
+    def testStreamKeying1(self):
+        """Test that connections succeed when stream keying is active."""
+
+        yield self.connect(self.get_body_node(connect=True, useKey=True))
+        yield self.proxy.send(self.get_body_node(useKey=True))
+        yield self.proxy.send(self.get_body_node(useKey=True))
+
+    @defer.inlineCallbacks
+    def testStreamKeying2(self):
+        """Test that 404 is received if stream keying is active and no key is supplied."""
+        yield self.connect(self.get_body_node(connect=True, useKey=True))
+
+        try:
+            yield self.proxy.send(self.get_body_node(useKey=False))
+        except httpb_client.HTTPBNetworkTerminated as e:
+            self.failUnlessEqual(e.body_tag.getAttribute('condition', None), 'item-not-found')
+        else:
+            self.fail("Expected 404 Not Found")
+
+
+    @defer.inlineCallbacks
+    def testStreamKeying3(self):
+        """Test that 404 is received if stream keying is active and an invalid key is supplied."""
+        yield self.connect(self.get_body_node(connect=True, useKey=True))
+
+        try:
+            yield self.proxy.send(self.get_body_node(useKey=True, key='0'*40))
+        except httpb_client.HTTPBNetworkTerminated as e:
+            self.failUnlessEqual(e.body_tag.getAttribute('condition', None), 'item-not-found')
+        else:
+            self.fail("Expected 404 Not Found")
+
+
+    @defer.inlineCallbacks
+    def testStreamKeying4(self):
+        """Test that 404 is received if we supply a key on a connection without active keying."""
+        yield self.connect(self.get_body_node(connect=True, useKey=False))
+
+        try:
+            yield self.proxy.send(self.get_body_node(key='0'*40))
+        except httpb_client.HTTPBNetworkTerminated as e:
+            self.failUnlessEqual(e.body_tag.getAttribute('condition', None), 'item-not-found')
+        else:
+            self.fail("Expected 404 Not Found")
+
+    @defer.inlineCallbacks
+    def testStreamKeying5(self):
+        """Test rekeying."""
+        yield self.connect(self.get_body_node(connect=True, useKey=True))
+        yield self.proxy.send(self.get_body_node(useKey=True))
+
+        # Erase all but the last key to force a rekeying.
+        self.keys.k = [self.keys.k[-1]]
+
+        yield self.proxy.send(self.get_body_node(useKey=True))
+        yield self.proxy.send(self.get_body_node(useKey=True))
+
+
+    def testStreamParseError(self):
+        """
+        Test that remote-connection-failed is received when the proxy receives invalid XML
+        from the XMPP server.
+        """
+
+        def _testStreamError(res):
+            if not isinstance(res.value, httpb_client.HTTPBNetworkTerminated):
+                return res
+
+            self.failUnlessEqual(res.value.body_tag.getAttribute('condition', None), 'remote-connection-failed')
+
+        def _failStreamError(res):
+            self.fail('Expected a remote-connection-failed error')
+
+        def _testSessionCreate(res):
+            self.sid = res[0]['sid']
+            self.server_protocol.triggerInvalidXML()
+            return self.send().addCallbacks(_failStreamError, _testStreamError)
+
+        return self.proxy.connect(self.get_body_node(connect=True)).addCallback(_testSessionCreate)
 
     def testFeaturesError(self):
         """
         This is to test if we get stream features and NOT twice
         """
-        
+
         def _testError(res):
             self.failUnless(res[1][0].name=='challenge','Did not get correct challenge stanza')
 
@@ -238,29 +367,28 @@ class XEP0124TestCase(test_basic.TestCase):
             # this xml is valid, just for testing
             # the point is to wait for a stream error
             self.failUnless(res[1][0].name=='features','Did not get initial features')
-            
-            # self.send("<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='DIGEST-MD5'/>")
-            d = self.send("<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='DIGEST-MD5'/>") 
+
+            d = self.send("<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='DIGEST-MD5'/>")
             d.addCallback(_testError)
             reactor.callLater(1.1, self.server_protocol.triggerChallenge)
             return d
-            
+
         BOSH_XML = """<body content='text/xml; charset=utf-8'
       hold='1'
-      rid='%d'
+      rid='%(rid)i'
       to='localhost'
-      route='xmpp:127.0.0.1:5222'
+      route='xmpp:127.0.0.1:%(server_port)i'
       ver='1.6'
       wait='15'
       ack='1'
       xml:lang='en'
       xmlns='http://jabber.org/protocol/httpbind'/>
- """ % (self.rid,)
+ """% { "rid": self.rid, "server_port": self.server_port }
         self.server_factory.protocol.delay_features = 3
 
         d = self.proxy.connect(BOSH_XML).addCallback(_testSessionCreate)
         # NOTE : to trigger this bug there needs to be 0 waiting requests.
-        
+
         return d
 
 
@@ -277,39 +405,38 @@ class XEP0124TestCase(test_basic.TestCase):
             # resend auth
             for r in range(5):
                 res = yield self.resend("<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='DIGEST-MD5'/>")
-            
+
             res = yield self.resend("<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='DIGEST-MD5'/>")
-                
+
 
         def _testSessionCreate(res):
             self.sid = res[0]['sid']
             # this xml is valid, just for testing
             # the point is to wait for a stream error
             self.failUnless(res[1][0].name=='features','Did not get initial features')
-            
-            # self.send("<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='DIGEST-MD5'/>")
+
             d = self.send("<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='DIGEST-MD5'/>")
             d.addCallback(_testError)
             reactor.callLater(1, self.server_protocol.triggerChallenge)
 
             return d
-            
+
         BOSH_XML = """<body content='text/xml; charset=utf-8'
       hold='1'
-      rid='%d'
+      rid='%(rid)i'
       to='localhost'
-      route='xmpp:127.0.0.1:5222'
+      route='xmpp:127.0.0.1:%(server_port)i'
       ver='1.6'
       wait='3'
       ack='1'
       xml:lang='en'
       xmlns='http://jabber.org/protocol/httpbind'/>
- """ % (self.rid,)
+ """% { "rid": self.rid, "server_port": self.server_port }
 
         self.server_factory.protocol.delay_features = 10
         d = self.proxy.connect(BOSH_XML).addCallback(_testSessionCreate)
         # NOTE : to trigger this bug there needs to be 0 waiting requests.
-        
+
         return d
-        
+
 
index 52bd671..05a7b4d 100644 (file)
@@ -1,6 +1,6 @@
 
 import os
-import sys, sha
+import sys
 from twisted.trial import unittest
 import time
 from twisted.words.protocols.jabber import jid