Report the original stream:error stanza to clients.
[punjab-krb5-preauth] / punjab / session.py
1 """
2  session stuff for jabber connections
3
4 """
5 from twisted.internet import defer,  reactor
6 from twisted.python import log
7 from twisted.web import server
8 from twisted.names.srvconnect import SRVConnector
9
10 try:
11     from twisted.words.xish import domish, xmlstream
12 except ImportError:
13     from twisted.xish import domish, xmlstream
14
15
16 import traceback
17 import random
18 import md5
19 from punjab import jabber
20 from punjab.xmpp import ns
21
22 import time
23 import error
24
25 try:
26     from twisted.internet import ssl
27 except ImportError:
28     log.msg("SSL ERROR: You do not have ssl support this may cause problems with tls client connections.")
29
30
31
32 class XMPPClientConnector(SRVConnector):
33     """
34     A jabber connection to find srv records for xmpp client connections.
35     """
36     def __init__(self, client_reactor, domain, factory):
37         """ Init """
38         SRVConnector.__init__(self, client_reactor, 'xmpp-client', domain, factory)
39         self.timeout = [1,3]
40
41     def pickServer(self):
42         """
43         Pick a server and port to make the connection.
44         """
45         host, port = SRVConnector.pickServer(self)
46
47         if not self.servers and not self.orderedServers:
48             # no SRV record, fall back..
49             port = 5222
50         if port == 5223 and xmlstream.ssl:
51             context = xmlstream.ssl.ClientContextFactory()
52             context.method = xmlstream.ssl.SSL.SSLv23_METHOD
53             
54             self.connectFunc = 'connectSSL'
55             self.connectFuncArgs = (context)
56         return host, port
57
58 def make_session(pint, attrs, session_type='BOSH'):
59     """
60     pint  - punjab session interface class
61     attrs - attributes sent from the body tag
62     """    
63
64     # this may need some work, idea, code taken from twisted.web.server
65     pint.counter = pint.counter + 1
66     sid  = md5.new("%s_%s_%s" % (str(time.time()), str(random.random()) , str(pint.counter))).hexdigest()
67
68
69     s    = Session(pint, sid, attrs)
70     
71     s.addBootstrap(xmlstream.STREAM_START_EVENT, s.streamStart)
72     s.addBootstrap(xmlstream.STREAM_CONNECTED_EVENT, s.connectEvent)
73     s.addBootstrap(xmlstream.STREAM_ERROR_EVENT, s.streamError)
74     s.addBootstrap(xmlstream.STREAM_END_EVENT, s.connectError)    
75     
76     s.inactivity = int(attrs.get('inactivity', 900)) # default inactivity 15 mins
77     
78     s.secure = 0
79     s.use_raw = getattr(pint, 'use_raw', False) # use raw buffers
80     
81     if attrs.has_key('secure') and attrs['secure'] == 'true':
82         s.secure = 1
83         s.authenticator.useTls = 1
84     else:
85         s.authenticator.useTls = 0
86
87     if pint.v:
88         log.msg('================================== %s connect to %s:%s ==================================' % (str(time.time()),s.hostname,s.port))
89         
90     connect_srv = True
91     if attrs.has_key('route'):
92         connect_srv = False
93     if s.hostname in ['localhost', '127.0.0.1']:
94         connect_srv = False
95     if not connect_srv:
96         reactor.connectTCP(s.hostname, s.port, s, bindAddress=pint.bindAddress)
97     else:
98         connector = XMPPClientConnector(reactor, s.hostname, s)
99         connector.connect()
100     # timeout
101     reactor.callLater(s.inactivity, s.checkExpired)
102
103     pint.sessions[sid] = s
104     
105     return s, s.waiting_requests[0].deferred
106     
107
108 class WaitingRequest(object):
109     """A helper object for managing waiting requests."""
110
111     def __init__(self, deferred, delayedcall, timeout = 30, startup = False, rid = None):
112         """ """
113         self.deferred    = deferred
114         self.delayedcall = delayedcall
115         self.startup     = startup
116         self.timeout     = timeout
117         self.wait_start  = time.time()
118         self.rid         = rid
119         
120     def doCallback(self, data):
121         """ """
122         self.deferred.callback(data)
123
124     def doErrback(self, data):
125         """ """
126         self.deferred.errback(data)
127
128
129 class Session(jabber.JabberClientFactory, server.Session):
130     """ Jabber Client Session class for client XMPP connections. """
131     def __init__(self, pint, sid, attrs):
132         """
133         Initialize the session
134         """
135         if attrs.has_key('charset'):
136             self.charset = str(attrs['charset'])
137         else:
138             self.charset = 'utf-8'
139         
140         self.to    = attrs['to']
141         self.port  = 5222
142         self.inactivity = 900
143         if self.to != '' and self.to.find(":") != -1:
144             # Check if port is in the 'to' string
145             to, port = self.to.split(':')
146             
147             if port:
148                 self.to   = to
149                 self.port = int(port)
150             else:
151                 self.port = 5222
152         
153         jabber.JabberClientFactory.__init__(self, self.to, pint.v)
154         server.Session.__init__(self, pint, sid)
155         self.pint  = pint
156
157         self.sid   = sid
158         self.attrs = attrs
159         self.s     = None
160
161         self.elems = []
162         rid        = int(attrs['rid'])
163
164         self.waiting_requests = []
165         self.use_raw = attrs.get('raw', False)
166
167         self.raw_buffer = u""
168         self.xmpp_node  = ''       
169         self.success    = 0        
170         self.secure     = 0
171         self.mechanisms = []
172         self.xmlstream  = None
173         self.features   = None
174         self.session    = None
175         
176         self.cache_data = {}
177         self.verbose    = self.pint.v
178         self.noisy      = self.verbose
179
180         self.version = attrs.get('version', 0.0)
181                 
182         if attrs.has_key('newkey'):
183             newkey   = attrs['newkey']
184             self.key = newkey
185         
186         self.wait  = int(attrs.get('wait', 0))            
187
188         self.hold  = int(attrs.get('hold', 0))
189
190         if attrs.has_key('window'):
191             self.window  = int(attrs['window'])
192         else:
193             self.window  = self.hold + 2
194
195         if attrs.has_key('polling'):
196             self.polling  = int(attrs['polling'])
197         else:
198             self.polling  = 0
199            
200         if attrs.has_key('port'):
201             self.port = int(attrs['port'])
202
203         if attrs.has_key('hostname'):
204             self.hostname = attrs['hostname']
205         else:
206             self.hostname = self.to
207             
208         if attrs.has_key('route'):
209             if attrs['route'].startswith("xmpp:"):
210                 self.route = attrs['route'][5:]
211                 if self.route.startswith("//"):
212                     self.route = self.route[2:]
213
214                 # route format change, see http://www.xmpp.org/extensions/xep-0124.html#session-request
215                 rhostname, rport = self.route.split(":")
216                 self.port = int(rport)
217                 self.hostname = rhostname
218                 self.resource = ''
219             else:
220                 raise error.Error('internal-server-error')
221
222             
223         self.authid      = 0    
224         self.rid         = rid + 1
225         self.connected   = 0 # number of clients connected on this session 
226
227         self.notifyOnExpire(self.onExpire)
228         self.stream_error = None
229         if pint.v:
230             log.msg('Session Created : %s %s' % (str(self.sid),str(time.time()), ))
231         
232         # create the first waiting request
233         d = defer.Deferred()
234         timeout = 30
235         rid = self.rid - 1
236         self.appendWaitingRequest(d, rid, 
237                                   timeout=timeout, 
238                                   poll=self._startup_timeout,
239                                   startup=True,
240                                   )
241         
242     def rawDataIn(self, buf):
243         """ Log incoming data on the xmlstream """
244         if self.pint.v:
245             try:
246                 log.msg("SID: %s => RECV: %r" % (self.sid, buf,))
247             except:
248                 log.err()
249         if self.use_raw and self.authid:
250             if type(buf) == type(''):
251                 buf = unicode(buf, 'utf-8')
252             # add some raw data
253             self.raw_buffer = self.raw_buffer + buf
254
255         
256     def rawDataOut(self, buf):
257         """ Log outgoing data on the xmlstream """
258         try:
259             log.msg("SID: %s => SEND: %r" % (self.sid, buf,))
260         except:
261             log.err()
262
263     def _wrPop(self, data, i=0):
264         """Pop off a waiting requst, do callback, and cache request
265         """
266         wr = self.waiting_requests.pop(i)
267         wr.doCallback(data)
268         self._cacheData(wr.rid, data)
269
270     def clearWaitingRequests(self, hold = 0):
271         """clear number of requests given
272
273            hold - number of requests to clear, default is all
274         """ 
275         while len(self.waiting_requests) > hold:
276             self._wrPop([])
277
278     def _wrError(self, err, i = 0):
279         wr = self.waiting_requests.pop(i)
280         wr.doErrback(err)
281
282
283     def appendWaitingRequest(self, d, rid, timeout=None, poll=None, startup=False):
284         """append waiting request
285         """
286         if timeout is None:
287             timeout = self.wait
288         if poll is None:
289             poll = self._pollTimeout
290         self.waiting_requests.append(
291             WaitingRequest(d,
292                            poll,
293                            timeout = timeout,
294                            rid = rid,
295                            startup=startup))
296
297     def returnWaitingRequests(self):
298         """return a waiting request
299         """
300         while len(self.elems) > 0 and len(self.waiting_requests) > 0:
301             data = self.elems
302             self.elems = []
303             self._wrPop(data)
304
305
306     def onExpire(self):
307         """ When the session expires call this. """
308         if 'onExpire' in dir(self.pint):
309             self.pint.onExpire(self.sid)
310         if self.verbose and not getattr(self, 'terminated', False):
311             log.msg(self.sid)
312             log.msg(self.rid)
313             log.msg(self.waiting_requests)
314             log.msg('SESSION -> We have expired')
315         self.disconnect()
316     
317     def terminate(self):
318         """Terminates the session."""
319         self.wait = 0
320         self.terminated = True
321         if self.verbose:
322             log.msg('SESSION -> Terminate')
323         
324         # if there are any elements hanging around and waiting
325         # requests, send those off
326         self.returnWaitingRequests()
327         
328         self.clearWaitingRequests()
329
330         try:
331             self.expire()
332         except:
333             self.onExpire()
334         
335                 
336         return defer.succeed(self.elems)
337
338     def poll(self, d = None, rid = None):
339         """Handles the responses to requests.
340
341         This function is called for every request except session setup
342         and session termination.  It handles the reply portion of the
343         request by returning a deferred which will get called back
344         when there is data or when the wait timeout expires.
345         """
346         # queue this request
347         if d is None:
348             d = defer.Deferred()
349         if self.pint.error:
350             d.addErrback(self.pint.error)    
351         if not rid:
352             rid = self.rid - 1
353         self.appendWaitingRequest(d, rid)
354         # check if there is any data to send back to a request
355         self.returnWaitingRequests()
356         
357         # make sure we aren't queueing too many requests
358         self.clearWaitingRequests(self.hold)
359         return d
360
361     def _pollTimeout(self, d):
362         """Handle request timeouts.
363
364         Since the timeout function is called, we must return an empty
365         reply as there is no data to send back.
366         """
367         # find the request that timed out and reply
368         pop_eye = []
369         for i in range(len(self.waiting_requests)):
370             if self.waiting_requests[i].deferred == d:
371                 pop_eye.append(i)
372                 self.touch()
373
374         for i in pop_eye:
375             self._wrPop([],i)
376
377
378     def _pollForId(self, d):
379         if self.xmlstream.sid:
380             self.authid = self.xmlstream.sid
381         self._pollTimeout(d)
382         
383
384         
385     def connectEvent(self, xs):
386
387         self.version =  self.authenticator.version
388         self.xmlstream = xs
389         if self.pint.v:
390             # add logging for verbose output
391             
392             self.xmlstream.rawDataOutFn = self.rawDataOut
393         self.xmlstream.rawDataInFn = self.rawDataIn
394
395         if self.version == '1.0':
396             self.xmlstream.addObserver("/features", self.featuresHandler)
397             
398
399             
400     def streamStart(self, xs):
401         """
402         A xmpp stream has started
403         """
404         # This is done to fix the stream id problem, I should submit a bug to twisted bugs
405         
406         try:
407             
408             self.authid    = self.xmlstream.sid
409             
410             if not self.attrs.has_key('no_events'):
411                 
412                 self.xmlstream.addOnetimeObserver("/auth", self.stanzaHandler)
413                 self.xmlstream.addOnetimeObserver("/response", self.stanzaHandler)
414                 self.xmlstream.addOnetimeObserver("/success", self._saslSuccess)
415                 self.xmlstream.addOnetimeObserver("/failure", self._saslError)                    
416                 
417                 self.xmlstream.addObserver("/iq/bind", self.bindHandler)
418                 self.xmlstream.addObserver("/bind", self.stanzaHandler)
419                 
420                 self.xmlstream.addObserver("/challenge", self.stanzaHandler)
421                 self.xmlstream.addObserver("/message",  self.stanzaHandler)
422                 self.xmlstream.addObserver("/iq",  self.stanzaHandler)
423                 self.xmlstream.addObserver("/presence",  self.stanzaHandler)
424                 # TODO - we should do something like this
425                 # self.xmlstream.addObserver("/*",  self.stanzaHandler)
426                 
427         except:
428             log.err(traceback.print_exc())
429             self._wrError(error.Error("remote-connection-failed"))
430             self.disconnect()
431             
432
433     def featuresHandler(self, f):
434         """
435         handle stream:features
436         """
437         f.prefixes   = ns.XMPP_PREFIXES.copy()
438         
439         #check for tls
440         self.f = {}
441         for feature in f.elements():
442             self.f[(feature.uri, feature.name)] = feature
443         
444         starttls = (ns.TLS_XMLNS, 'starttls') in self.f
445         
446         initializers   = getattr(self.xmlstream, 'initializers', [])
447         self.features = f
448         self.xmlstream.features = f
449                 
450         # There is a tls initializer added by us, if it is available we need to try it
451         if len(initializers)>0 and starttls:
452             self.secure = 1
453
454         if self.authid is None:
455             self.authid = self.xmlstream.sid
456         
457
458         # If we get tls, then we should start tls, wait and then return
459         # Here we wait, the tls initializer will start it
460         if starttls and self.secure:
461             if self.verbose:
462                 log.msg("Wait until starttls is completed.")
463                 log.msg(initializers)
464             return
465         self.elems.append(f)
466         if len(self.waiting_requests) > 0:
467             self.returnWaitingRequests()
468             self.elems = [] # reset elems
469             self.raw_buffer = u"" # reset raw buffer, features should not be in it
470     
471     def bindHandler(self, stz):
472         """bind debugger for punjab, this is temporary! """
473         if self.verbose:
474             try:
475                 log.msg('BIND: %s %s' % (str(self.sid), str(stz.bind.jid)))
476             except:
477                 log.err()
478         if self.use_raw:
479             self.raw_buffer = stz.toXml()
480         
481     def stanzaHandler(self, stz):
482         """generic stanza handler for httpbind and httppoll"""
483         stz.prefixes = ns.XMPP_PREFIXES
484         if self.use_raw and self.authid:
485             stz = domish.SerializedXML(self.raw_buffer)
486             self.raw_buffer = u""
487
488         self.elems.append(stz)            
489         if self.waiting_requests and len(self.waiting_requests) > 0:
490             # if there are any waiting requests, give them all the
491             # data so far, plus this new data
492             self.returnWaitingRequests()
493
494
495     def _startup_timeout(self, d):
496         # this can be called if connection failed, or if we connected
497         # but never got a stream features before the timeout
498         if self.pint.v:
499             log.msg('================================== %s %s startup timeout ==================================' % (str(self.sid), str(time.time()),))
500         for i in range(len(self.waiting_requests)):
501             if self.waiting_requests[i].deferred == d:
502                 # check if we really failed or not
503                 if self.authid:
504                     self._wrPop(self.elems, i=i)
505                 else:
506                     self._wrError(error.Error("remote-connection-failed"), i=i)
507                     
508     
509     def buildRemoteError(self, err_elem=None):
510         e = error.Error('remote-stream-error')
511         e.error_stanza = 'remote-stream-error'
512         e.children = []
513         if err_elem:
514             e.children.append(err_elem)            
515         return e
516
517     def streamError(self, streamerror):
518         """called when we get a stream:error stanza"""
519         
520         err_elem = getattr(streamerror.value, "element")
521
522         e = self.buildRemoteError(err_elem)
523         do_expire = True
524         
525         if len(self.waiting_requests) > 0:
526             wr = self.waiting_requests.pop(0)            
527             wr.doErrback(e)
528         else: # need to wait for a new request and then expire
529             do_expire = False
530             
531         if self.pint and self.pint.sessions.has_key(self.sid):
532             if do_expire:
533                 try:
534                     self.expire()
535                 except:
536                     self.onExpire()
537             else:
538                 s = self.pint.sessions.get(self.sid)
539                 s.stream_error = e
540                                 
541
542     def connectError(self, xs):
543         """called when we get disconnected"""
544         
545         # FIXME: we should really only send the error event back if
546         # attempts to reconnect fail.  There's no reason temporary
547         # connection failures should be exposed upstream
548         if self.verbose:
549             log.msg('connect ERROR')
550             try:
551                 log.msg(xs)
552                 
553             except:
554                 pass
555             
556         if self.waiting_requests:
557                         
558             if len(self.waiting_requests) > 0:
559                 wr = self.waiting_requests.pop(0)
560                 wr.doErrback(error.Error('remote-connection-failed'))
561
562         if self.pint and self.pint.sessions.has_key(self.sid):
563             try:
564                 self.expire()
565             except:
566                 self.onExpire()
567
568
569     def sendRawXml(self, obj):
570         """
571         Send a raw xml string, not a domish.Element
572         """
573         self.touch()
574         self._send(obj)
575     
576     
577     def _send(self, xml):
578         """
579         Send valid data over the xmlstream
580         """
581         if self.xmlstream: # FIXME this happens on an expired session and the post has something to send
582             if isinstance(xml, domish.Element):
583                 xml.localPrefixes = {}
584             self.xmlstream.send(xml)
585
586     def _removeObservers(self, typ = ''):
587         if typ == 'event':
588             observers = self.xmlstream._eventObservers
589         else:
590             observers = self.xmlstream._xpathObservers
591         emptyLists = []
592         for priority, priorityObservers in observers.iteritems():
593             for query, callbacklist in priorityObservers.iteritems():
594                 callbacklist.callbacks = []
595                 emptyLists.append((priority, query))
596
597         for priority, query in emptyLists:
598             del observers[priority][query]
599     
600     def disconnect(self):
601         """
602         Disconnect from the xmpp server.
603         """
604         if not getattr(self, 'xmlstream',None):
605             return
606         
607         if self.xmlstream:
608             #sh = "<presence type='unavailable' xmlns='jabber:client'/>"
609             sh = "</stream:stream>"
610             self.xmlstream.send(sh)
611             
612         self.stopTrying()
613         if self.xmlstream:
614             self.xmlstream.transport.loseConnection()
615          
616             del self.xmlstream
617         self.connected = 0
618         self.pint      = None
619         self.elems     = []
620         
621         if self.waiting_requests:
622             self.clearWaitingRequests()
623             del self.waiting_requests
624         self.mechanisms = None
625         self.features   = None
626         
627
628     
629     def checkExpired(self):
630         """
631         Check if the session or xmpp connection has expired
632         """
633         # send this so we do not timeout from servers
634         if getattr(self, 'xmlstream', None):
635             self.xmlstream.send(' ')
636         if self.inactivity is None:
637             wait = 900
638         elif self.inactivity == 0:
639             wait = time.time()
640         
641         else:
642             wait = self.inactivity
643
644         if self.waiting_requests and len(self.waiting_requests)>0:
645             wait += self.wait # if we have pending requests we need to add the wait time
646             
647         if time.time() - self.lastModified > wait+(0.1):
648             if self.site.sessions.has_key(self.uid):
649                 self.terminate()
650             else:
651                 pass
652              
653         else:
654             reactor.callLater(wait, self.checkExpired)
655
656
657     def _cacheData(self, rid, data):
658         if len(self.cache_data.keys())>=3:
659             # remove the first one in
660             keys = self.cache_data.keys()
661             keys.sort()
662             del self.cache_data[keys[0]]
663         
664         self.cache_data[int(rid)] = data
665         
666 # This stuff will leave when SASL and TLS are implemented correctly
667 # session stuff
668
669     def _sessionResultEvent(self, iq):
670         """ """
671         if len(self.waiting_requests)>0:                
672                 wr = self.waiting_requests.pop(0)
673                 d  = wr.deferred
674         else:
675                 d = None
676
677         if iq["type"] == "result":
678             if d:
679                 d.callback(self)
680         else:
681             if d:
682                 d.errback(self)
683
684
685     def _saslSuccess(self, s):
686         """ """
687         self.success = 1
688         self.s = s
689         # return success to the client
690         if len(self.waiting_requests)>0:
691             self._wrPop([s])
692
693         self.authenticator._reset()
694         if self.use_raw:
695             self.raw_buffer = u""
696
697
698
699     def _saslError(self, sasl_error, d = None):
700         """ SASL error """
701         
702         if d:
703             d.errback(self)
704         if len(self.waiting_requests)>0:                
705             self._wrPop([sasl_error])