root/flup/server/scgi_base.py

Revision 77:3454cb15a7dd, 18.2 kB (checked in by Allan Saddi <allan@saddi.com>, 2 months ago)

Add an indication as to which header fails assertion when
passing in non-string header names and/or values.

Line 
1 # Copyright (c) 2005, 2006 Allan Saddi <allan@saddi.com>
2 # All rights reserved.
3 #
4 # Redistribution and use in source and binary forms, with or without
5 # modification, are permitted provided that the following conditions
6 # are met:
7 # 1. Redistributions of source code must retain the above copyright
8 #    notice, this list of conditions and the following disclaimer.
9 # 2. Redistributions in binary form must reproduce the above copyright
10 #    notice, this list of conditions and the following disclaimer in the
11 #    documentation and/or other materials provided with the distribution.
12 #
13 # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
14 # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16 # ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
17 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
18 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
19 # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
20 # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
21 # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
22 # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
23 # SUCH DAMAGE.
24 #
25 # $Id$
26
# Module metadata. '$Revision$' appears to be a version-control keyword
# placeholder that was left unexpanded in this copy.
__version__ = '$Revision$'
__author__ = 'Allan Saddi <allan@saddi.com>'
29
30 import sys
31 import logging
32 import socket
33 import select
34 import errno
35 import cStringIO as StringIO
36 import signal
37 import datetime
38 import os
39 import warnings
40
41 # Threads are required. If you want a non-threaded (forking) version, look at
42 # SWAP <http://www.idyll.org/~t/www-tools/wsgi/>.
43 import thread
44 import threading
45
# Names exported by "from <module> import *": only the server base class.
__all__ = ["BaseSCGIServer"]
47
class NoDefault(object):
    """
    Sentinel used (as the class object itself) for default argument values
    meaning "not supplied", so that an explicit None can carry its own
    meaning (e.g. allowedServers=None means "accept from anywhere").
    """
50
# The main classes use this name for logging.
LoggerName = 'scgi-wsgi'

# Configure the module-level logger once, at import time: attach a console
# StreamHandler (stderr by default) that lets every level through and
# prefixes each message with a timestamp. The temporary name is removed
# afterwards so it does not leak into the module namespace.
console = logging.StreamHandler()
console.setFormatter(logging.Formatter(fmt='%(asctime)s : %(message)s',
                                       datefmt='%Y-%m-%d %H:%M:%S'))
console.setLevel(logging.DEBUG)
logging.getLogger(LoggerName).addHandler(console)
del console
61
class ProtocolError(Exception):
    """
    Raised when the peer (the web server) sends something unexpected or
    garbled, such as a malformed netstring or bad SCGI headers. The
    connection is usually closed as a result.
    """
68
def recvall(sock, length):
    """
    Attempts to receive length bytes from a socket, blocking if necessary.
    (Socket may be blocking or non-blocking.)

    Returns a (data, received) tuple. received is smaller than length only
    if the peer closed the connection (EOF) first. A length of zero or less
    returns ('', 0) without touching the socket.

    Reads interrupted by a signal (EINTR) are retried. On a non-blocking
    socket with no data available (EAGAIN/EWOULDBLOCK), waits with select().
    Any other socket.error propagates.
    """
    # Never ask recv() for more than this at once: recv(n) reserves an
    # n-byte buffer up front, and length may come straight from a
    # peer-supplied netstring prefix.
    maxChunk = 65536

    dataList = []
    recvLen = 0
    while length > 0:
        try:
            data = sock.recv(min(length, maxChunk))
        except socket.error:
            # sys.exc_info() instead of "except X, e" keeps this parseable
            # by both old and new Python versions.
            e = sys.exc_info()[1]
            code = None
            if e.args:
                code = e.args[0]
            if code == errno.EINTR:
                # Interrupted by a signal before any data arrived; retry.
                continue
            if code in (errno.EAGAIN, errno.EWOULDBLOCK):
                # Non-blocking socket with nothing to read yet: wait for it.
                select.select([sock], [], [])
                continue
            raise
        if not data: # EOF
            break
        dataList.append(data)
        recvLen += len(data)
        length -= len(data)
    return ''.join(dataList), recvLen
92
def readNetstring(sock):
    """
    Read one netstring ("<length>:<data>,") from a socket and return <data>.

    The length prefix is read one byte at a time so that nothing past the
    ':' separator is consumed before the length is known.

    Raises EOFError if the peer closes the connection before the netstring
    is complete. Raises ProtocolError if the length prefix is empty,
    contains anything other than decimal digits or is longer than 10
    digits, or if the trailing ',' is missing.
    """
    # Cap on the number of length digits. Without it, a peer that never
    # sends ':' would make us buffer an arbitrarily long prefix.
    maxDigits = 10

    # First attempt to read the length.
    digits = []
    while True:
        try:
            c = sock.recv(1)
        except socket.error:
            # sys.exc_info() instead of "except X, e" keeps this parseable
            # by both old and new Python versions.
            e = sys.exc_info()[1]
            code = None
            if e.args:
                code = e.args[0]
            if code == errno.EINTR:
                # Interrupted by a signal; just retry.
                continue
            if code in (errno.EAGAIN, errno.EWOULDBLOCK):
                # Non-blocking socket with nothing to read yet.
                select.select([sock], [], [])
                continue
            raise
        if c == ':':
            break
        if not c:
            raise EOFError
        # Reject garbage as soon as it shows up instead of after ':'.
        if c not in '0123456789' or len(digits) >= maxDigits:
            raise ProtocolError('invalid netstring length')
        digits.append(c)

    if not digits:
        raise ProtocolError('invalid netstring length')
    size = int(''.join(digits))

    # Now read the string.
    s, length = recvall(sock, size)
    if length < size:
        raise EOFError

    # Lastly, the trailer.
    trailer, length = recvall(sock, 1)
    if length < 1:
        raise EOFError
    if trailer != ',':
        raise ProtocolError('invalid netstring trailer')

    return s
138
class StdoutWrapper(object):
    """
    Proxy around an output file object that records whether any non-empty
    data has been written through it (see dataWritten).

    write() and writelines() are intercepted. Every other attribute lookup
    is forwarded to the wrapped file.
    """
    def __init__(self, stdout):
        self._file = stdout
        # Flips to True on the first non-empty write() and never resets.
        self.dataWritten = False

    def write(self, data):
        if data:
            self.dataWritten = True
        self._file.write(data)

    def writelines(self, lines):
        # Route each piece through write() so dataWritten stays accurate.
        for piece in lines:
            self.write(piece)

    def __getattr__(self, name):
        # Only reached when normal lookup fails, i.e. for everything except
        # the attributes defined on this wrapper.
        target = self._file
        return getattr(target, name)
158
class Request(object):
    """
    Encapsulates data related to a single request.

    Public attributes:
      environ - Environment variables from web server.
      stdin - File-like object representing the request body.
      stdout - File-like object for writing the response.
    """
    def __init__(self, conn, environ, input, output):
        self._conn = conn
        self.environ = environ
        self.stdin = input
        # Wrapped so run() can tell whether any response data went out
        # before a handler exception.
        self.stdout = StdoutWrapper(output)

        self.logger = logging.getLogger(LoggerName)

    def run(self):
        """
        Pass this request to the server's handler, logging the request and
        how long the handler took.

        Any exception raised by the handler is logged. If no response data
        has been written yet, the server's error() method is then called
        (presumably to emit an error response). Once output has started,
        nothing further is attempted.
        """
        self.logger.info('%s %s%s',
                         self.environ['REQUEST_METHOD'],
                         self.environ.get('SCRIPT_NAME', ''),
                         self.environ.get('PATH_INFO', ''))

        start = datetime.datetime.now()

        try:
            self._conn.server.handler(self)
        except:
            self.logger.exception('Exception caught from handler')
            if not self.stdout.dataWritten:
                self._conn.server.error(self)

        end = datetime.datetime.now()

        self.logger.debug('%s %s%s done (%.3f secs)',
                          self.environ['REQUEST_METHOD'],
                          self.environ.get('SCRIPT_NAME', ''),
                          self.environ.get('PATH_INFO', ''),
                          self._elapsedSeconds(end - start))

    @staticmethod
    def _elapsedSeconds(delta):
        """
        Return a datetime.timedelta as a float number of seconds.

        Unlike delta.seconds alone, this includes the days component. Runs
        longer than a day, and small negative deltas (e.g. the wall clock
        being stepped back mid-request, which timedelta normalizes to
        days=-1, seconds=86399...), are therefore reported correctly.
        """
        return (delta.days * 86400 + delta.seconds +
                delta.microseconds / 1000000.0)
200
201 class Connection(object):
202     """
203     Represents a single client (web server) connection. A single request
204     is handled, after which the socket is closed.
205     """
206     def __init__(self, sock, addr, server):
207         self._sock = sock
208         self._addr = addr
209         self.server = server
210
211         self.logger = logging.getLogger(LoggerName)
212
213     def run(self):
214         if len(self._addr) == 2:
215             self.logger.debug('Connection starting up (%s:%d)',
216                               self._addr[0], self._addr[1])
217
218         try:
219             self.processInput()
220         except (EOFError, KeyboardInterrupt):
221             pass
222         except ProtocolError, e:
223             self.logger.error("Protocol error '%s'", str(e))
224         except:
225             self.logger.exception('Exception caught in Connection')
226
227         if len(self._addr) == 2:
228             self.logger.debug('Connection shutting down (%s:%d)',
229                               self._addr[0], self._addr[1])
230
231         # All done!
232         self._sock.close()
233
234     def processInput(self):
235         # Read headers
236         headers = readNetstring(self._sock)
237         headers = headers.split('\x00')[:-1]
238         if len(headers) % 2 != 0:
239             raise ProtocolError, 'invalid headers'
240         environ = {}
241         for i in range(len(headers) / 2):
242             environ[headers[2*i]] = headers[2*i+1]
243
244         clen = environ.get('CONTENT_LENGTH')
245         if clen is None:
246             raise ProtocolError, 'missing CONTENT_LENGTH'
247         try:
248             clen = int(clen)
249             if clen < 0:
250                 raise ValueError
251         except ValueError:
252             raise ProtocolError, 'invalid CONTENT_LENGTH'
253
254         self._sock.setblocking(1)
255         if clen:
256             input = self._sock.makefile('r')
257         else:
258             # Empty input.
259             input = StringIO.StringIO()
260
261         # stdout
262         output = self._sock.makefile('w')
263
264         # Allocate Request
265         req = Request(self, environ, input, output)
266
267         # Run it.
268         req.run()
269
270         output.close()
271         input.close()
272
273 class BaseSCGIServer(object):
274     # What Request class to use.
275     requestClass = Request
276
277     def __init__(self, application, scriptName=NoDefault, environ=None,
278                  multithreaded=True, multiprocess=False,
279                  bindAddress=('localhost', 4000), umask=None,
280                  allowedServers=NoDefault,
281                  loggingLevel=logging.INFO, debug=True):
282         """
283         scriptName is the initial portion of the URL path that "belongs"
284         to your application. It is used to determine PATH_INFO (which doesn't
285         seem to be passed in). An empty scriptName means your application
286         is mounted at the root of your virtual host.
287
288         environ, which must be a dictionary, can contain any additional
289         environment variables you want to pass to your application.
290
291         Set multithreaded to False if your application is not thread-safe.
292
293         Set multiprocess to True to explicitly set wsgi.multiprocess to
294         True. (Only makes sense with threaded servers.)
295
296         bindAddress is the address to bind to, which must be a string or
297         a tuple of length 2. If a tuple, the first element must be a string,
298         which is the host name or IPv4 address of a local interface. The
299         2nd element of the tuple is the port number. If a string, it will
300         be interpreted as a filename and a UNIX socket will be opened.
301
302         If binding to a UNIX socket, umask may be set to specify what
303         the umask is to be changed to before the socket is created in the
304         filesystem. After the socket is created, the previous umask is
305         restored.
306         
307         allowedServers must be None or a list of strings representing the
308         IPv4 addresses of servers allowed to connect. None means accept
309         connections from anywhere. By default, it is a list containing
310         the single item '127.0.0.1'.
311
312         loggingLevel sets the logging level of the module-level logger.
313         """
314         if environ is None:
315             environ = {}
316
317         self.application = application
318         self.scriptName = scriptName
319         self.environ = environ
320         self.multithreaded = multithreaded
321         self.multiprocess = multiprocess
322         self.debug = debug
323         self._bindAddress = bindAddress
324         self._umask = umask
325         if allowedServers is NoDefault:
326             allowedServers = ['127.0.0.1']
327         self._allowedServers = allowedServers
328
329         # Used to force single-threadedness.
330         self._appLock = thread.allocate_lock()
331
332         self.logger = logging.getLogger(LoggerName)
333         self.logger.setLevel(loggingLevel)
334
335     def _setupSocket(self):
336         """Creates and binds the socket for communication with the server."""
337         oldUmask = None
338         if type(self._bindAddress) is str:
339             # Unix socket
340             sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
341             try:
342                 os.unlink(self._bindAddress)
343             except OSError:
344                 pass
345             if self._umask is not None:
346                 oldUmask = os.umask(self._umask)
347         else:
348             # INET socket
349             assert type(self._bindAddress) is tuple
350             assert len(self._bindAddress) == 2
351             sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
352             sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
353
354         sock.bind(self._bindAddress)
355         sock.listen(socket.SOMAXCONN)
356
357         if oldUmask is not None:
358             os.umask(oldUmask)
359
360         return sock
361
362     def _cleanupSocket(self, sock):
363         """Closes the main socket."""
364         sock.close()
365
366     def _isClientAllowed(self, addr):
367         ret = self._allowedServers is None or \
368               len(addr) != 2 or \
369               (len(addr) == 2 and addr[0] in self._allowedServers)
370         if not ret:
371             self.logger.warning('Server connection from %s disallowed',
372                                 addr[0])
373         return ret
374
375     def handler(self, request):
376         """
377         WSGI handler. Sets up WSGI environment, calls the application,
378         and sends the application's response.
379         """
380         environ = request.environ
381         environ.update(self.environ)
382
383         environ['wsgi.version'] = (1,0)
384         environ['wsgi.input'] = request.stdin
385         environ['wsgi.errors'] = sys.stderr
386         environ['wsgi.multithread'] = self.multithreaded
387         environ['wsgi.multiprocess'] = self.multiprocess
388         environ['wsgi.run_once'] = False
389
390         if environ.get('HTTPS', 'off') in ('on', '1'):
391             environ['wsgi.url_scheme'] = 'https'
392         else:
393             environ['wsgi.url_scheme'] = 'http'
394
395         self._sanitizeEnv(environ)
396
397         headers_set = []
398         headers_sent = []
399         result = None
400
401         def write(data):
402             assert type(data) is str, 'write() argument must be string'
403             assert headers_set, 'write() before start_response()'
404
405             if not headers_sent:
406                 status, responseHeaders = headers_sent[:] = headers_set
407                 found = False
408                 for header,value in responseHeaders:
409                     if header.lower() == 'content-length':
410                         found = True
411                         break
412                 if not found and result is not None:
413                     try:
414                         if len(result) == 1:
415                             responseHeaders.append(('Content-Length',
416                                                     str(len(data))))
417                     except:
418                         pass
419                 s = 'Status: %s\r\n' % status
420                 for header in responseHeaders:
421                     s += '%s: %s\r\n' % header
422                 s += '\r\n'
423                 request.stdout.write(s)
424
425             request.stdout.write(data)
426             request.stdout.flush()
427
428         def start_response(status, response_headers, exc_info=None):
429             if exc_info:
430                 try:
431                     if headers_sent:
432                         # Re-raise if too late
433                         raise exc_info[0], exc_info[1], exc_info[2]
434                 finally:
435                     exc_info = None # avoid dangling circular ref
436             else:
437                 assert not headers_set, 'Headers already set!'
438
439             assert type(status) is str, 'Status must be a string'
440             assert len(status) >= 4, 'Status must be at least 4 characters'
441             assert int(status[:3]), 'Status must begin with 3-digit code'
442             assert status[3] == ' ', 'Status must have a space after code'
443             assert type(response_headers) is list, 'Headers must be a list'
444             if __debug__:
445                 for name,val in response_headers:
446                     assert type(name) is str, 'Header name "%s" must be a string' % name
447                     assert type(val) is str, 'Value of header "%s" must be a string' % name
448
449             headers_set[:] = [status, response_headers]
450             return write
451
452         if not self.multithreaded:
453             self._appLock.acquire()
454         try:
455             try:
456                 result = self.application(environ, start_response)
457                 try:
458                     for data in result:
459                         if data:
460                             write(data)
461                     if not headers_sent:
462                         write('') # in case body was empty
463                 finally:
464                     if hasattr(result, 'close'):
465                         result.close()
466             except socket.error, e:
467                 if e[0] != errno.EPIPE:
468                     raise # Don't let EPIPE propagate beyond server
469         finally:
470             if not self.multithreaded:
471                 self._appLock.release()
472
473     def _sanitizeEnv(self, environ):
474         """Fill-in/deduce missing values in environ."""
475         reqUri = None
476         if environ.has_key('REQUEST_URI'):
477             reqUri = environ['REQUEST_URI'].split('?', 1)
478
479         # Ensure QUERY_STRING exists
480         if not environ.has_key('QUERY_STRING') or not environ['QUERY_STRING']:
481             if reqUri is not None and len(reqUri) > 1:
482                 environ['QUERY_STRING'] = reqUri[1]
483             else:
484                 environ['QUERY_STRING'] = ''
485
486         # Check WSGI_SCRIPT_NAME
487         scriptName = environ.get('WSGI_SCRIPT_NAME')
488         if scriptName is None:
489             scriptName = self.scriptName
490         else:
491             warnings.warn('WSGI_SCRIPT_NAME environment variable for scgi '
492                           'servers is deprecated',
493                           DeprecationWarning)
494             if scriptName.lower() == 'none':
495                 scriptName = None
496
497         if scriptName is None:
498             # Do nothing (most likely coming from cgi2scgi)
499             return
500
501         if scriptName is NoDefault:
502             # Pull SCRIPT_NAME/PATH_INFO from environment, with empty defaults
503             if not environ.has_key('SCRIPT_NAME'):
504                 environ['SCRIPT_INFO'] = ''
505             if not environ.has_key('PATH_INFO') or not environ['PATH_INFO']:
506                 if reqUri is not None:
507                     environ['PATH_INFO'] = reqUri[0]
508                 else:
509                     environ['PATH_INFO'] = ''
510         else:
511             # Configured scriptName
512             warnings.warn('Configured SCRIPT_NAME is deprecated\n'
513                           'Do not use WSGI_SCRIPT_NAME or the scriptName\n'
514                           'keyword parameter -- they will be going away',
515                           DeprecationWarning)
516
517             value = environ['SCRIPT_NAME']
518             value += environ.get('PATH_INFO', '')
519             if not value.startswith(scriptName):
520                 self.logger