|  | @@ -18,10 +18,15 @@ class ConnectionClosed(Error):
 | 
	
		
			
				|  |  |      """Connection closed by remote host"""
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +class PathNotAllowed(Error):
 | 
	
		
			
				|  |  | +    """Repository path not allowed"""
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  class RepositoryServer(object):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def __init__(self):
 | 
	
		
			
				|  |  | +    def __init__(self, restrict_to_paths):
 | 
	
		
			
				|  |  |          self.repository = None
 | 
	
		
			
				|  |  | +        self.restrict_to_paths = restrict_to_paths
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def serve(self):
 | 
	
		
			
				|  |  |          # Make stdin non-blocking
 | 
	
	
		
			
				|  | @@ -61,11 +66,19 @@ class RepositoryServer(object):
 | 
	
		
			
				|  |  |          path = os.fsdecode(path)
 | 
	
		
			
				|  |  |          if path.startswith('/~'):
 | 
	
		
			
				|  |  |              path = path[1:]
 | 
	
		
			
				|  |  | -        self.repository = Repository(os.path.expanduser(path), create)
 | 
	
		
			
				|  |  | +        path = os.path.realpath(os.path.expanduser(path))
 | 
	
		
			
				|  |  | +        if self.restrict_to_paths:
 | 
	
		
			
				|  |  | +            for restrict_to_path in self.restrict_to_paths:
 | 
	
		
			
				|  |  | +                if path.startswith(os.path.realpath(restrict_to_path)):
 | 
	
		
			
				|  |  | +                    break
 | 
	
		
			
				|  |  | +            else:
 | 
	
		
			
				|  |  | +                raise PathNotAllowed(path)
 | 
	
		
			
				|  |  | +        self.repository = Repository(path, create)
 | 
	
		
			
				|  |  |          return self.repository.id
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class RemoteRepository(object):
 | 
	
		
			
				|  |  | +    extra_test_args = []
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      class RPCError(Exception):
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -83,7 +96,7 @@ class RemoteRepository(object):
 | 
	
		
			
				|  |  |          self.unpacker = msgpack.Unpacker(use_list=False)
 | 
	
		
			
				|  |  |          self.p = None
 | 
	
		
			
				|  |  |          if location.host == '__testsuite__':
 | 
	
		
			
				|  |  | -            args = [sys.executable, '-m', 'attic.archiver', 'serve']
 | 
	
		
			
				|  |  | +            args = [sys.executable, '-m', 'attic.archiver', 'serve'] + self.extra_test_args
 | 
	
		
			
				|  |  |          else:
 | 
	
		
			
				|  |  |              args = ['ssh']
 | 
	
		
			
				|  |  |              if location.port:
 | 
	
	
		
			
				|  | @@ -139,6 +152,8 @@ class RemoteRepository(object):
 | 
	
		
			
				|  |  |                              raise Repository.CheckNeeded(self.location.orig)
 | 
	
		
			
				|  |  |                          elif error == b'IntegrityError':
 | 
	
		
			
				|  |  |                              raise IntegrityError(res)
 | 
	
		
			
				|  |  | +                        elif error == b'PathNotAllowed':
 | 
	
		
			
				|  |  | +                            raise PathNotAllowed(*res)
 | 
	
		
			
				|  |  |                          raise self.RPCError(error)
 | 
	
		
			
				|  |  |                      else:
 | 
	
		
			
				|  |  |                          yield res
 |