forked from bloomreach/s4cmd
-
Notifications
You must be signed in to change notification settings - Fork 0
/
s4cmd.py
executable file
·1947 lines (1671 loc) · 68.9 KB
/
s4cmd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python
#
# Copyright 2012-2018 BloomReach, Inc.
# Portions Copyright 2014 Databricks
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Super S3 command line tool.
"""
import sys, os, re, optparse, multiprocessing, fnmatch, time, hashlib, errno, pytz
import logging, traceback, types, threading, random, socket, shlex, datetime, json
IS_PYTHON2 = sys.version_info[0] == 2

# Expose StringIO / Queue / ConfigParser under their Python 2 module names on
# both interpreters, and restore the cmp() builtin that Python 3 removed.
if not IS_PYTHON2:
    from io import BytesIO as StringIO
    import queue as Queue
    import configparser as ConfigParser

    def cmp(a, b):
        '''Python 2 style three-way comparison: -1, 0 or 1 for ordinary values.'''
        return (a > b) - (a < b)
else:
    from cStringIO import StringIO
    import Queue
    import ConfigParser

# functools.cmp_to_key only exists from Python 2.7 on; older interpreters fall
# back to the project's own utils module.
if sys.version_info >= (2, 7):
    from functools import cmp_to_key
else:
    from utils import cmp_to_key
# ---------------------------------------------------------------------------
# Global constants
# ---------------------------------------------------------------------------
S4CMD_VERSION = "2.1.0"
PATH_SEP = '/'
DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S UTC'
TIMESTAMP_FORMAT = '%04d-%02d-%02d %02d:%02d'

# Default socket timeout, in seconds: give up when no recv() callback has been
# received for this long. Applied process-wide at import time.
SOCKET_TIMEOUT = 300
socket.setdefaulttimeout(SOCKET_TIMEOUT)

# Names of temp files created for atomic downloads (see tempfile_get); any
# still present are removed by clean_tempfiles().
TEMP_FILES = set()

# Environment variables consulted for S3 credentials and extra s4cmd options.
S3_ACCESS_KEY_NAME = "AWS_ACCESS_KEY_ID"
S3_SECRET_KEY_NAME = "AWS_SECRET_ACCESS_KEY"
S4CMD_ENV_KEY = "S4CMD_OPTS"
# ---------------------------------------------------------------------------
# Utility classes
# ---------------------------------------------------------------------------
class Failure(RuntimeError):
    '''Raised for unrecoverable runtime failures (not retried by workers).'''
class InvalidArgument(RuntimeError):
    '''Raised when an input parameter (e.g. an S3 URL) is malformed.'''
class RetryFailure(Exception):
    '''Runtime failure for which a retry may succeed.'''
class S4cmdLoggingClass:
    '''Owns the "s4cmd" logger and the stderr handler attached to it.'''

    def __init__(self):
        self.log = logging.Logger("s4cmd")
        self.log.stream = sys.stderr
        self.log_handler = logging.StreamHandler(self.log.stream)
        self.log.addHandler(self.log_handler)

    def configure(self, opt):
        '''Configure the logger based on command-line arguments.

        opt.debug   -> verbosity 3, DEBUG level, messages prefixed with
                       level, file name and line number.
        opt.verbose -> verbosity 2, INFO level.
        otherwise   -> verbosity 1, ERROR level.
        '''
        self.log_handler.setFormatter(logging.Formatter('%(message)s', DATETIME_FORMAT))
        if opt.debug:
            self.log.verbosity = 3
            self.log_handler.setFormatter(logging.Formatter(
                ' (%(levelname).1s)%(filename)s:%(lineno)-4d %(message)s',
                DATETIME_FORMAT))
            self.log.setLevel(logging.DEBUG)
        elif opt.verbose:
            self.log.verbosity = 2
            self.log.setLevel(logging.INFO)
        else:
            self.log.verbosity = 1
            self.log.setLevel(logging.ERROR)

    def get_loggers(self):
        '''Return the logger methods as a tuple: (debug, info, warning, error).'''
        # Logger.warn is a deprecated alias of Logger.warning; export the
        # supported method (callers still bind it to the name `warn`).
        return self.log.debug, self.log.info, self.log.warning, self.log.error


s4cmd_logging = S4cmdLoggingClass()
debug, info, warn, error = s4cmd_logging.get_loggers()
def get_default_thread_count():
    '''Number of worker threads: $S4CMD_NUM_THREADS when set, else 4 per CPU.'''
    per_cpu_default = multiprocessing.cpu_count() * 4
    configured = os.getenv('S4CMD_NUM_THREADS', per_cpu_default)
    return int(configured)
def log_calls(func):
    '''Decorator that logs every call of func at debug level.

    Before the call it logs ">> name(args...)", after it "<< name(args...): result".
    functools.wraps keeps the wrapped function's __name__, __doc__ and other
    metadata, so decorated methods still introspect and debug-log correctly.
    '''
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kargs):
        callStr = "%s(%s)" % (func.__name__, ", ".join(
            [repr(p) for p in args] +
            ["%s=%s" % (k, repr(v)) for (k, v) in list(kargs.items())]))
        debug(">> %s", callStr)
        ret = func(*args, **kargs)
        debug("<< %s: %s", callStr, repr(ret))
        return ret
    return wrapper
# ---------------------------------------------------------------------------
# Utility functions
# ---------------------------------------------------------------------------
def synchronized(func):
    '''Decorator that serializes all calls to func through a single lock.

    The lock is created once per decorated function and stored as
    func.__lock__, so every caller shares it (for a method, that includes all
    instances). functools.wraps preserves func's name/docstring and copies its
    attributes onto the wrapper.
    '''
    import functools
    func.__lock__ = threading.Lock()

    @functools.wraps(func)
    def synced_func(*args, **kargs):
        with func.__lock__:
            return func(*args, **kargs)
    return synced_func
def clear_progress():
    '''Wipe the progress message currently shown on stderr, if any.

    Implemented as progress() with an empty message, which blanks the old
    text and records '' as the current message.
    '''
    empty_message = ''
    progress(empty_message)
@synchronized
def progress(msg, *args):
    '''Show a transient progress message on stderr.

    The previous message is remembered (progress.prev_message) so it can be
    blanked out before the next one is written; lines end with '\\r' so they
    overwrite each other. Nothing is shown unless both stdout and stderr are
    terminals.
    '''
    # Don't show any progress if the output is directed to a file.
    if not (sys.stdout.isatty() and sys.stderr.isatty()):
        return
    # %-format only when arguments were given (the same convention logging
    # uses), so a literal '%' in a pre-formatted message cannot raise.
    text = msg % args if args else msg
    if progress.prev_message:
        sys.stderr.write(' ' * len(progress.prev_message) + '\r')
    sys.stderr.write(text + '\r')
    progress.prev_message = text
progress.prev_message = None
@synchronized
def message(msg, *args):
    '''Write a program message line to stdout, clearing any progress text first.

    msg is %-formatted with args only when args are given. Callers such as
    put_single_file pass already-formatted text, for example
    message('omitting directory "%s".' % source). Formatting that text a
    second time would raise on any '%' in the path.
    '''
    clear_progress()
    text = msg % args if args else msg
    sys.stdout.write(text + '\n')
def fail(message, exc_info=None, status=1, stacktrace=False):
    '''Report a fatal error, remove temp files, then stop.

    Logs `message` (with str(exc_info) appended when given) and, if
    `stacktrace` is set, the current traceback. When this module runs as a
    script the process exits with `status`; otherwise RuntimeError(status)
    is raised so importing code can catch it.
    '''
    text = message + str(exc_info) if exc_info else message
    error(text)
    if stacktrace:
        error(traceback.format_exc())
    clean_tempfiles()
    if __name__ != '__main__':
        raise RuntimeError(status)
    sys.exit(status)
@synchronized
def tempfile_get(target):
    '''Reserve a temp filename next to target for an atomic download.

    The name is "<target>-<15 random lowercase letters/digits>.tmp". It is
    recorded in TEMP_FILES so clean_tempfiles() can remove it on abort.
    '''
    alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
    suffix = ''.join(random.Random().sample(alphabet, 15))
    name = '%s-%s.tmp' % (target, suffix)
    TEMP_FILES.add(name)
    return name
@synchronized
def tempfile_set(tempfile, target):
    '''Finalize a temp file from tempfile_get().

    If target is truthy, the temp file is renamed onto it (atomic replace of
    the download); otherwise the temp file is deleted. In both cases the temp
    name stops being tracked in TEMP_FILES.
    '''
    if target:
        os.rename(tempfile, target)
    else:
        os.unlink(tempfile)
    # Untrack the name that tempfile_get() registered. Testing `target` for
    # membership (as before) never matched, so finished names piled up.
    TEMP_FILES.discard(tempfile)
def clean_tempfiles():
    '''Remove every temp file still registered in TEMP_FILES.

    Called from fail(), possibly in a worker thread while other threads are
    still adding/removing entries, so it iterates over a snapshot.
    '''
    for fn in list(TEMP_FILES):
        try:
            os.unlink(fn)
        except OSError as e:
            # Already renamed or removed (e.g. by a concurrent tempfile_set).
            if e.errno != errno.ENOENT:
                raise
class S3URL:
    '''Simple wrapper for an S3 URL (s3://bucket/path or s3n://bucket/path).

    Attributes:
      proto:  always 's3' ('s3n' is normalized).
      bucket: the bucket name.
      path:   everything after "bucket/" (no leading slash); may be empty and
              may contain '*' / '?' wildcards.
    '''
    S3URL_PATTERN = re.compile(r'(s3[n]?)://([^/]+)[/]?(.*)')

    def __init__(self, uri):
        '''Parse uri; raise InvalidArgument if it is not an S3 URL.'''
        try:
            self.proto, self.bucket, self.path = S3URL.S3URL_PATTERN.match(uri).groups()
        except (AttributeError, TypeError):
            # match() returns None for non-S3 strings (AttributeError on
            # .groups()); non-string input raises TypeError.
            raise InvalidArgument('Invalid S3 URI: %s' % uri)
        self.proto = 's3'  # normalize s3n => s3

    def __str__(self):
        '''Return the URL rebuilt from its components (s3n shown as s3).'''
        return S3URL.combine(self.proto, self.bucket, self.path)

    def get_fixed_path(self):
        '''Return the leading path components before the first one that
        contains a '*' or '?' wildcard (the whole path if there is none).'''
        fixed = []
        for part in self.path.split(PATH_SEP):
            if '*' in part or '?' in part:
                break
            fixed.append(part)
        return PATH_SEP.join(fixed)

    @staticmethod
    def combine(proto, bucket, path):
        '''Join the components into an S3 URL string. No path normalization is
        done here; path should not start with a slash.'''
        return '%s://%s/%s' % (proto, bucket, path)

    @staticmethod
    def is_valid(uri):
        '''Return True if uri matches the S3 URL pattern.'''
        return S3URL.S3URL_PATTERN.match(uri) is not None
class BotoClient(object):
    '''Bridge between s4cmd and the boto3 library; all S3 calls go through it.

    Calls to the methods whitelisted in ALLOWED_CLIENT_METHODS are intercepted
    (see __getattribute__). Options from EXTRA_CLIENT_PARAMS that were given on
    the command line (as --API-<Name>) are merged into the request before it
    is forwarded to the underlying boto3 client.
    '''
    # Encapsulate boto3 interface intercept all API calls.
    boto3 = __import__('boto3')  # version >= 1.3.1
    botocore = __import__('botocore')

    # Exported exceptions.
    BotoError = boto3.exceptions.Boto3Error
    ClientError = botocore.exceptions.ClientError
    NoCredentialsError = botocore.exceptions.NoCredentialsError

    # Exceptions that retries may work. May change in the future.
    S3RetryableErrors = (
        socket.timeout,
        socket.error if IS_PYTHON2 else ConnectionError,
        botocore.vendored.requests.packages.urllib3.exceptions.ReadTimeoutError,
        botocore.exceptions.IncompleteReadError
    )

    # List of API functions we use in s4cmd.
    ALLOWED_CLIENT_METHODS = [
        'list_buckets',
        'get_paginator',
        'head_object',
        'put_object',
        'create_bucket',
        'create_multipart_upload',
        'upload_part',
        'complete_multipart_upload',
        'abort_multipart_upload',
        'get_object',
        'copy_object',
        'delete_object',
        'delete_objects',
        'upload_part_copy'
    ]

    # List of parameters grabbed from http://boto3.readthedocs.io/en/latest/reference/services/s3.html
    # Pass those parameters directly to boto3 low level API. Most of the parameters are not tested.
    EXTRA_CLIENT_PARAMS = [
        ("ACL", "string",
         "The canned ACL to apply to the object."),
        ("CacheControl", "string",
         "Specifies caching behavior along the request/reply chain."),
        ("ContentDisposition", "string",
         "Specifies presentational information for the object."),
        ("ContentEncoding", "string",
         "Specifies what content encodings have been applied to the object and thus what decoding mechanisms must be applied to obtain the media-type referenced by the Content-Type header field."),
        ("ContentLanguage", "string",
         "The language the content is in."),
        ("ContentMD5", "string",
         "The base64-encoded 128-bit MD5 digest of the part data."),
        ("ContentType", "string",
         "A standard MIME type describing the format of the object data."),
        ("CopySourceIfMatch", "string",
         "Copies the object if its entity tag (ETag) matches the specified tag."),
        ("CopySourceIfModifiedSince", "datetime",
         "Copies the object if it has been modified since the specified time."),
        ("CopySourceIfNoneMatch", "string",
         "Copies the object if its entity tag (ETag) is different than the specified ETag."),
        ("CopySourceIfUnmodifiedSince", "datetime",
         "Copies the object if it hasn't been modified since the specified time."),
        ("CopySourceRange", "string",
         "The range of bytes to copy from the source object. The range value must use the form bytes=first-last, where the first and last are the zero-based byte offsets to copy. For example, bytes=0-9 indicates that you want to copy the first ten bytes of the source. You can copy a range only if the source object is greater than 5 GB."),
        ("CopySourceSSECustomerAlgorithm", "string",
         "Specifies the algorithm to use when decrypting the source object (e.g., AES256)."),
        ("CopySourceSSECustomerKeyMD5", "string",
         "Specifies the 128-bit MD5 digest of the encryption key according to RFC 1321. Amazon S3 uses this header for a message integrity check to ensure the encryption key was transmitted without error. Please note that this parameter is automatically populated if it is not provided. Including this parameter is not required"),
        ("CopySourceSSECustomerKey", "string",
         "Specifies the customer-provided encryption key for Amazon S3 to use to decrypt the source object. The encryption key provided in this header must be one that was used when the source object was created."),
        ("ETag", "string",
         "Entity tag returned when the part was uploaded."),
        ("Expires", "datetime",
         "The date and time at which the object is no longer cacheable."),
        ("GrantFullControl", "string",
         "Gives the grantee READ, READ_ACP, and WRITE_ACP permissions on the object."),
        ("GrantReadACP", "string",
         "Allows grantee to read the object ACL."),
        ("GrantRead", "string",
         "Allows grantee to read the object data and its metadata."),
        ("GrantWriteACP", "string",
         "Allows grantee to write the ACL for the applicable object."),
        ("IfMatch", "string",
         "Return the object only if its entity tag (ETag) is the same as the one specified, otherwise return a 412 (precondition failed)."),
        ("IfModifiedSince", "datetime",
         "Return the object only if it has been modified since the specified time, otherwise return a 304 (not modified)."),
        ("IfNoneMatch", "string",
         "Return the object only if its entity tag (ETag) is different from the one specified, otherwise return a 304 (not modified)."),
        ("IfUnmodifiedSince", "datetime",
         "Return the object only if it has not been modified since the specified time, otherwise return a 412 (precondition failed)."),
        ("Metadata", "dict",
         "A map (in json string) of metadata to store with the object in S3"),
        ("MetadataDirective", "string",
         "Specifies whether the metadata is copied from the source object or replaced with metadata provided in the request."),
        ("MFA", "string",
         "The concatenation of the authentication device's serial number, a space, and the value that is displayed on your authentication device."),
        ("RequestPayer", "string",
         "Confirms that the requester knows that she or he will be charged for the request. Bucket owners need not specify this parameter in their requests. Documentation on downloading objects from requester pays buckets can be found at http://docs.aws.amazon.com/AmazonS3/latest/dev/ObjectsinRequesterPaysBuckets.html"),
        ("ServerSideEncryption", "string",
         "The Server-side encryption algorithm used when storing this object in S3 (e.g., AES256, aws:kms)."),
        ("SSECustomerAlgorithm", "string",
         "Specifies the algorithm to use to when encrypting the object (e.g., AES256)."),
        ("SSECustomerKeyMD5", "string",
         "Specifies the 128-bit MD5 digest of the encryption key according to RFC 1321. Amazon S3 uses this header for a message integrity check to ensure the encryption key was transmitted without error. Please note that this parameter is automatically populated if it is not provided. Including this parameter is not required"),
        ("SSECustomerKey", "string",
         "Specifies the customer-provided encryption key for Amazon S3 to use in encrypting data. This value is used to store the object and then it is discarded; Amazon does not store the encryption key. The key must be appropriate for use with the algorithm specified in the x-amz-server-side-encryption-customer-algorithm header."),
        ("SSEKMSKeyId", "string",
         "Specifies the AWS KMS key ID to use for object encryption. All GET and PUT requests for an object protected by AWS KMS will fail if not made via SSL or using SigV4. Documentation on configuring any of the officially supported AWS SDKs and CLI can be found at http://docs.aws.amazon.com/AmazonS3/latest/dev/UsingAWSSDK.html#specify-signature-version"),
        ("StorageClass", "string",
         "The type of storage to use for the object. Defaults to 'STANDARD'."),
        ("VersionId", "string",
         "VersionId used to reference a specific version of the object."),
        ("WebsiteRedirectLocation", "string",
         "If the bucket is configured as a website, redirects requests for this object to another object in the same bucket or to an external URL. Amazon S3 stores the value of this header in the object metadata."),
    ]

    def __init__(self, opt, aws_access_key_id=None, aws_secret_access_key=None):
        '''Create the boto3 S3 client and cache the legal parameters of every
        whitelisted method.

        Explicit credentials are used only when both key id and secret are
        given; otherwise boto3's own credential lookup applies.
        opt.endpoint_url is passed through in both cases.
        '''
        self.opt = opt
        if (aws_access_key_id is not None) and (aws_secret_access_key is not None):
            self.client = self.boto3.client('s3',
                                            aws_access_key_id=aws_access_key_id,
                                            aws_secret_access_key=aws_secret_access_key,
                                            endpoint_url=opt.endpoint_url)
        else:
            self.client = self.boto3.client('s3', endpoint_url=opt.endpoint_url)

        # Cache the result so we don't have to recalculate.
        self.legal_params = {}
        for method in BotoClient.ALLOWED_CLIENT_METHODS:
            self.legal_params[method] = self.get_legal_params(method)

    def __getattribute__(self, method):
        '''Intercept boto3 API call to inject our extra options.

        Whitelisted names return a wrapper that merges command-line options
        (merge_opt_params), debug-logs the call and its result, and forwards
        the call to the boto3 client. Every other name resolves normally.
        '''
        if method in BotoClient.ALLOWED_CLIENT_METHODS:
            def wrapped_method(*args, **kargs):
                merged_kargs = self.merge_opt_params(method, kargs)
                callStr = "%s(%s)" % ("S3APICALL " + method, ", ".join([repr(p) for p in args] + ["%s=%s" % (k, repr(v)) for (k, v) in list(kargs.items())]))
                debug(">> %s", callStr)
                ret = getattr(self.client, method)(*args, **merged_kargs)
                debug("<< %s: %s", callStr, repr(ret))
                return ret
            return wrapped_method
        return super(BotoClient, self).__getattribute__(method)

    def get_legal_params(self, method):
        '''Return the input parameter names of the API behind `method`, read
        from the boto3 service model. Returns [] for methods with no mapped API
        (client-side helpers) or with no input shape.'''
        if method not in self.client.meta.method_to_api_mapping:
            # Injected methods. Ignore.
            return []
        api = self.client.meta.method_to_api_mapping[method]
        shape = self.client.meta.service_model.operation_model(api).input_shape
        if shape is None:
            # No params needed for this API.
            return []
        return shape.members.keys()

    def merge_opt_params(self, method, kargs):
        '''Merge command-line API options into the keyword arguments of a call.

        For every parameter `method` accepts that is set (not None) on
        self.opt: when the call and the option both supply a dict (e.g.
        Metadata), the dicts are merged with the option's entries winning;
        otherwise the option value overwrites the call's value.
        Returns kargs, updated in place.
        '''
        for key in self.legal_params[method]:
            value = getattr(self.opt, key, None)
            if value is None:
                continue
            if key in kargs and isinstance(kargs[key], dict):
                assert isinstance(value, dict)
                # Merge into a copy so the caller's dict is left untouched
                # (dict.iteritems(), used previously, does not exist on Python 3).
                merged = dict(kargs[key])
                merged.update(value)
                kargs[key] = merged
            else:
                # Overwrite values.
                kargs[key] = value
        return kargs

    @staticmethod
    def add_options(parser):
        '''Register every EXTRA_CLIENT_PARAMS entry as an --API-<Name> optparse option.'''
        for param, param_type, param_doc in BotoClient.EXTRA_CLIENT_PARAMS:
            parser.add_option('--API-' + param, help=param_doc, type=param_type, dest=param)

    def close(self):
        '''Close this client by dropping the boto3 client reference.'''
        self.client = None
class TaskQueue(Queue.Queue):
    '''Task queue whose join() does not leave the main thread stuck.

    The main thread must not stay blocked on child threads where a Ctrl-C
    cannot wake it. join() therefore waits in bounded slices and reacts to
    KeyboardInterrupt. It also fails fast once a worker has reported an
    exception through terminate().
    '''

    def __init__(self):
        Queue.Queue.__init__(self)
        self.exc_info = None

    def join(self):
        '''Block until every task is done.

        Raises Failure on Ctrl-C. Calls fail() once a worker has recorded an
        exception via terminate().
        '''
        try:
            with self.all_tasks_done:
                while self.unfinished_tasks:
                    self.all_tasks_done.wait(1000)
                    # Child thread has exceptions, fail main thread too.
                    if self.exc_info:
                        fail('[Thread Failure] ', exc_info=self.exc_info)
        except KeyboardInterrupt:
            raise Failure('Interrupted by user')

    def terminate(self, exc_info=None):
        '''Discard every pending task so the workers run dry.

        Remembers exc_info (if given) so join() can report it.
        '''
        if exc_info:
            self.exc_info = exc_info
        # Drain until Empty instead of stopping at the first falsy item, so
        # every dequeued item is matched by task_done() and join() can finish.
        while True:
            try:
                self.get_nowait()
            except Queue.Empty:
                break
            self.task_done()
class ThreadPool(object):
    '''Fixed-size pool of worker threads fed from a TaskQueue.

    thread_class must derive from ThreadPool.Worker. Calling pool.<name>(...)
    for a method <name> defined directly on thread_class enqueues the call;
    a worker thread later runs it (see __getattribute__ and Worker.run).
    '''

    class Worker(threading.Thread):
        '''Base class for pool threads.

        Takes (func_name, retry, args, kargs) items off the pool's queue and
        runs the method named func_name on itself. Retryable S3/network errors
        put the item back (up to opt.retry times). Any other error terminates
        the queue and calls fail().
        '''

        def __init__(self, pool):
            '''Bind to the pool and start the daemon thread right away.'''
            threading.Thread.__init__(self)
            self.pool = pool
            self.opt = pool.opt
            self.daemon = True
            self.start()

        def run(self):
            '''Execute queued tasks until a falsy item (the stop sentinel) arrives.'''
            while True:
                item = self.pool.tasks.get()
                if not item:
                    break

                try:
                    func_name, retry, args, kargs = item
                    self.__class__.__dict__[func_name](self, *args, **kargs)
                except InvalidArgument as e:
                    self.pool.tasks.terminate(e)
                    fail('[Invalid Argument] ', exc_info=e)
                except Failure as e:
                    self.pool.tasks.terminate(e)
                    fail('[Runtime Failure] ', exc_info=e)
                # Must come before the OSError clause: on Python 3,
                # socket.timeout and ConnectionError are OSError subclasses,
                # so listing OSError first made them fatal and disabled retries.
                except BotoClient.S3RetryableErrors as e:
                    if retry >= self.opt.retry:
                        self.pool.tasks.terminate(e)
                        fail('[Runtime Exception] ', exc_info=e, stacktrace=True)
                    else:
                        # Show content of exceptions.
                        error(e)
                        time.sleep(self.opt.retry_delay)
                        self.pool.tasks.put((func_name, retry + 1, args, kargs))
                except OSError as e:
                    self.pool.tasks.terminate(e)
                    # errno can be None (e.g. for socket.timeout); '%d' would
                    # raise TypeError inside this handler.
                    fail('[OSError] %s: %s' % (e.errno, e.strerror))
                except Exception as e:
                    self.pool.tasks.terminate(e)
                    fail('[Exception] ', exc_info=e)
                finally:
                    self.pool.processed()
                    self.pool.tasks.task_done()

    def __init__(self, thread_class, opt):
        '''Create opt.num_threads workers of thread_class (each starts immediately).

        The pool exposes thread_class's methods as deferred tasks through
        __getattribute__().
        '''
        self.opt = opt
        self.tasks = TaskQueue()
        self.processed_tasks = 0
        self.thread_class = thread_class
        self.workers = []
        for i in range(opt.num_threads):
            self.workers.append(thread_class(self))

    def __enter__(self):
        '''Utility function for with statement'''
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        '''Utility function for with statement, wait for completion'''
        self.join()
        # NOTE(review): returning True here suppresses any TypeError raised in
        # the with-body; kept for compatibility, but confirm it is intended.
        return isinstance(exc_value, TypeError)

    def __getattribute__(self, name):
        '''Resolve pool attributes; unknown names that are methods of
        thread_class become deferred calls that enqueue a task.'''
        try:
            attr = super(ThreadPool, self).__getattribute__(name)
        except AttributeError as e:
            if name in self.thread_class.__dict__:
                # Here we masquerade the original function with add_task(). So the
                # function call will be put into task queue.
                def deferred_task(*args, **kargs):
                    self.add_task(name, *args, **kargs)
                attr = deferred_task
            else:
                raise AttributeError('Unable to resolve %s' % name)
        return attr

    def add_task(self, func_name, *args, **kargs):
        '''Queue a call of func_name with retry count 0.'''
        self.tasks.put((func_name, 0, args, kargs))

    def join(self):
        '''Wait for all tasks, then stop and join every worker thread.'''
        self.tasks.join()
        # Force each thread to break loop.
        for worker in self.workers:
            self.tasks.put(None)
        # Wait for all thread to terminate.
        for worker in self.workers:
            worker.join()
            worker.s3 = None

    @synchronized
    def processed(self):
        '''Increase the processed task counter and show progress message'''
        self.processed_tasks += 1
        qsize = self.tasks.qsize()
        if qsize > 0:
            progress('[%d task(s) completed, %d remaining, %d thread(s)]', self.processed_tasks, qsize, len(self.workers))
        else:
            progress('[%d task(s) completed, %d thread(s)]', self.processed_tasks, len(self.workers))
class S3Handler(object):
'''Core S3 class.
This class provide the functions for all operations. It will start thread
pool to execute tasks generated by each operation. See ThreadUtil for
more details about the tasks.
'''
S3_KEYS = None
@staticmethod
def s3_keys_from_env():
    '''Return (access_key, secret_key) from the AWS_* environment variables,
    or None unless both variables are set.'''
    environment = os.environ
    if S3_ACCESS_KEY_NAME not in environment or S3_SECRET_KEY_NAME not in environment:
        return None
    pair = (environment[S3_ACCESS_KEY_NAME], environment[S3_SECRET_KEY_NAME])
    debug("read S3 keys from environment")
    return pair
@staticmethod
def s3_keys_from_cmdline(opt):
    '''Return (access_key, secret_key) given on the command line, or None
    unless both were supplied.'''
    if opt.access_key is None or opt.secret_key is None:
        return None
    pair = (opt.access_key, opt.secret_key)
    debug("read S3 keys from commandline")
    return pair
@staticmethod
def s3_keys_from_s3cfg(opt):
    '''Return (access_key, secret_key) from s3cmd's config file, or None.

    Uses opt.s3cfg when given, else "$HOME/.s3cfg" (falling back to the
    user's home directory from os.path.expanduser when HOME is unset).
    Returns None if the file is missing or cannot be read/parsed; the
    failure is logged at info level.
    '''
    # Bound before the try so the except clause can always report it;
    # previously a failure while computing it left the name unbound and the
    # handler itself crashed.
    s3cfg_path = None
    try:
        if opt.s3cfg is not None:
            s3cfg_path = "%s" % opt.s3cfg
        else:
            home = os.environ.get("HOME") or os.path.expanduser("~")
            s3cfg_path = "%s/.s3cfg" % home
        if not os.path.exists(s3cfg_path):
            return None
        config = ConfigParser.ConfigParser()
        config.read(s3cfg_path)
        keys = config.get("default", "access_key"), config.get("default", "secret_key")
        debug("read S3 keys from %s file", s3cfg_path)
        return keys
    except Exception as e:
        info("could not read S3 keys from %s file; skipping (%s)", s3cfg_path, e)
        return None
@staticmethod
def init_s3_keys(opt):
    '''Set S3Handler.S3_KEYS from the first source that yields keys.

    Precedence: command line, then environment, then the s3cfg file.
    '''
    keys = S3Handler.s3_keys_from_cmdline(opt)
    if not keys:
        keys = S3Handler.s3_keys_from_env()
    if not keys:
        keys = S3Handler.s3_keys_from_s3cfg(opt)
    S3Handler.S3_KEYS = keys
def __init__(self, opt):
    '''Store the options and open the S3 client right away.'''
    self.opt = opt
    self.s3 = None
    self.connect()

def __del__(self):
    '''Drop the S3 client reference.'''
    self.s3 = None

def connect(self):
    '''Create the BotoClient, using S3Handler.S3_KEYS when they are set.

    Any error during creation is re-raised as RetryFailure.
    '''
    keys = S3Handler.S3_KEYS
    try:
        if not keys:
            self.s3 = BotoClient(self.opt)
        else:
            self.s3 = BotoClient(self.opt, keys[0], keys[1])
    except Exception as e:
        raise RetryFailure('Unable to connect to s3: %s' % e)
@log_calls
def list_buckets(self):
    '''Return one ls-style entry (name, is_dir, size, last_modified) per bucket.'''
    buckets = self.s3.list_buckets().get('Buckets') or []
    return [
        {
            'name': S3URL.combine('s3', bucket['Name'], ''),
            'is_dir': True,
            'size': 0,
            'last_modified': bucket['CreationDate'],
        }
        for bucket in buckets
    ]
@log_calls
def s3walk(self, basedir, show_dir=None):
    '''Walk an S3 "directory" (wildcards allowed) with a thread pool.

    When show_dir is falsy it falls back to opt.show_dir. If directory display
    is off and the walk finds exactly one entry that is a directory, that
    directory's contents are listed instead. Entries come back with
    directories first, then sorted by name.
    '''
    # Provide the default value from command line if no override.
    if not show_dir:
        show_dir = self.opt.show_dir

    # Trailing slash normalization, so that 's3://foo/bar/' lists the same as
    # 's3://foo/bar'. Wildcard matching depends on the number of slashes, so
    # exactly one trailing slash is dropped.
    if basedir[-1] == PATH_SEP:
        basedir = basedir[0:-1]

    def walk(url):
        '''Run one threaded walk of url and return the collected entries.'''
        found = []
        pool = ThreadPool(ThreadUtil, self.opt)
        pool.s3walk(url, url.get_fixed_path(), url.path, found)
        pool.join()
        return found

    result = walk(S3URL(basedir))

    # automatic directory detection
    if not show_dir and len(result) == 1 and result[0]['is_dir']:
        result = walk(S3URL(result[0]['name']))

    # Directories first (is_dir True sorts before False), then by name.
    return sorted(result, key=lambda entry: (not entry['is_dir'], entry['name']))
@log_calls
def local_walk(self, basedir):
    '''Return the paths of all files os.walk reports under basedir, recursively.'''
    return [os.path.join(root, name)
            for root, _subdirs, names in os.walk(basedir)
            for name in names]
@log_calls
def get_basename(self, path):
    '''Unix-style basename that ignores trailing slashes.

    Returns 'bar' for '/foo/bar/' (and for '/foo/bar//') instead of an empty
    string; it is used to normalize trailing slashes in user input. An empty
    path yields '' instead of raising IndexError.
    '''
    return os.path.basename(path.rstrip(PATH_SEP))
def source_expand(self, source):
    '''Expand wildcards in one S3 path or a list of them, emulating shell
    expansion for local paths.

    Returns the matching object names. Calls fail() when nothing matches,
    unless opt.ignore_empty_source is set.
    '''
    sources = source if isinstance(source, list) else [source]
    result = []
    # XXX Hacky: recursion must be off while the input parameters are
    # expanded. Restore it in `finally` so an exception from s3walk cannot
    # leave opt.recursive permanently disabled.
    saved_recursive = self.opt.recursive
    self.opt.recursive = False
    try:
        for src in sources:
            result += [f['name'] for f in self.s3walk(src, True)]
    finally:
        self.opt.recursive = saved_recursive

    if not result and not self.opt.ignore_empty_source:
        fail("[Runtime Failure] Source doesn't exist.")
    return result
@log_calls
def put_single_file(self, pool, source, target):
    '''Queue the upload of one local file, or of a directory tree in recursive mode.

    A directory without opt.recursive is skipped with a message.
    '''
    if not os.path.isdir(source):
        pool.upload(source, target)
        return
    if not self.opt.recursive:
        message('omitting directory "%s".' % source)
        return
    for local_file in self.local_walk(source):
        if os.path.isdir(local_file):
            continue
        target_url = S3URL(target)
        # deal with ./ or ../ here by normalizing the path.
        relative = os.path.relpath(local_file, source)
        key = os.path.normpath(os.path.join(target_url.path, relative))
        pool.upload(local_file, S3URL.combine('s3', target_url.bucket, key))
@log_calls
def put_files(self, source, target):
    '''Upload one or more local files/directories to S3.

    With a target ending in '/', each source is uploaded under its basename.
    Otherwise exactly one source is required. In recursive mode directories
    are uploaded with their structure preserved.

    Raises Failure if several sources are given but target lacks a trailing slash.
    '''
    sources = source if isinstance(source, list) else [source]
    to_directory = target[-1] == PATH_SEP

    # Validate before starting worker threads, so a rejected call does not
    # leave an idle pool behind.
    if not to_directory and len(sources) != 1:
        raise Failure('Target "%s" is not a directory (with a trailing slash).' % target)

    pool = ThreadPool(ThreadUtil, self.opt)
    if to_directory:
        for src in sources:
            self.put_single_file(pool, src, os.path.join(target, self.get_basename(src)))
    else:
        self.put_single_file(pool, sources[0], target)
    pool.join()
@log_calls
def create_bucket(self, source):
    '''Create the bucket named in the S3 URL `source`.

    In dry-run mode only the message is printed. Raises Failure unless S3
    answers with HTTP 200.
    '''
    s3url = S3URL(source)
    message('Creating %s', source)
    if self.opt.dry_run:
        return
    resp = self.s3.create_bucket(Bucket=s3url.bucket)
    status_code = resp['ResponseMetadata']["HTTPStatusCode"]
    if status_code != 200:
        raise Failure('Unable to create bucket %s' % source)
    message('Done.')
@log_calls
def update_privilege(self, obj, target):
    '''chmod target to the octal mode stored in obj's 'privilege' metadata.

    obj is an S3 object response dict. When there is no Metadata entry or no
    'privilege' key, nothing happens; a missing Metadata entry no longer
    raises KeyError.
    '''
    metadata = obj.get('Metadata') or {}
    if 'privilege' in metadata:
        os.chmod(target, int(metadata['privilege'], 8))
@log_calls
def print_files(self, source):
    '''Print the content of every S3 object matched by source to stdout.'''
    for src in self.source_expand(source):
        s3url = S3URL(src)
        response = self.s3.get_object(Bucket=s3url.bucket, Key=s3url.path)
        body = response['Body']
        try:
            data = body.read()
        finally:
            # Release the underlying HTTP connection even if read() fails.
            body.close()
        # On Python 3, read() returns bytes, and '%s' would print their b'...' repr.
        if not isinstance(data, str):
            data = data.decode('utf-8', 'replace')
        message('%s', data)
@log_calls
def get_single_file(self, pool, source, target):
    '''Queue the download of one S3 object, or of a "directory" (source ending
    in '/') in recursive mode; otherwise such a directory is skipped.'''
    if source[-1] != PATH_SEP:
        pool.download(source, target)
    elif not self.opt.recursive:
        message('omitting directory "%s".' % source)
    else:
        basepath = S3URL(source).path
        for entry in self.s3walk(source):
            if entry['is_dir']:
                continue
            relative = os.path.relpath(S3URL(entry['name']).path, basepath)
            pool.download(entry['name'], os.path.join(target, relative))
@log_calls
def get_files(self, source, target):
    '''Download the S3 object(s) matched by `source` to the local `target`.

    `source` may contain wildcards and is expanded with source_expand.
    - If `target` is an existing local directory, each match is placed inside
      it under its basename. S3 "directories" are handled (recursively, if
      enabled) by get_single_file.
    - Otherwise at most one match is allowed, and it is downloaded to `target`.

    An empty expansion is a no-op. The original notes this can only happen
    when ignore-empty-source is enabled.

    Raises:
      Failure: more than one match while `target` is not a directory.
    '''
    pool = ThreadPool(ThreadUtil, self.opt)
    matches = self.source_expand(source)

    if os.path.isdir(target):
        for src in matches:
            local_path = os.path.join(target, self.get_basename(S3URL(src).path))
            self.get_single_file(pool, src, local_path)
    elif len(matches) > 1:
        raise Failure('Target "%s" is not a directory.' % target)
    elif matches:
        # Exactly one match: download it to `target` itself.
        self.get_single_file(pool, matches[0], target)

    pool.join()
@log_calls
def delete_removed_files(self, source, target):
    '''Remove remote files under `target` that have no local counterpart.

    (Obsolete) Only used by the old sync command.

    `source` must be a local directory and `target` an S3 URL prefix. Each
    non-directory object under `target` is mapped to the same relative path
    under `source`. If no regular file exists there, the remote object is
    queued for deletion on a thread pool. No pool is created when nothing
    needs deleting.

    Raises:
      Failure: `source` is not a local directory.
    '''
    message("Deleting files found in %s and not in %s", source, target)
    if not os.path.isdir(source):
        # Report the path that failed the check (`source`, not `target`).
        raise Failure('Source "%s" is not a directory.' % source)

    basepath = S3URL(target).path
    unnecessary = []
    for f in [f for f in self.s3walk(target) if not f['is_dir']]:
        local_name = os.path.join(source, os.path.relpath(S3URL(f['name']).path, basepath))
        if not os.path.isfile(local_name):
            message("%s not found locally, adding to delete queue", local_name)
            unnecessary.append(f['name'])

    if unnecessary:
        pool = ThreadPool(ThreadUtil, self.opt)
        for del_file in unnecessary:
            pool.delete(del_file)
        pool.join()
@log_calls
def cp_single_file(self, pool, source, target, delete_source):
    '''Queue an S3-to-S3 copy of one object or one S3 "directory".

    A `source` ending with PATH_SEP is a directory. In recursive mode every
    non-directory entry under it is queued for copy to the same relative key
    under `target`; otherwise it is skipped with a message. Any other `source`
    is queued as a single copy to `target`. `delete_source` is forwarded
    unchanged to every pool.copy call.
    '''
    if source[-1] != PATH_SEP:
        pool.copy(source, target, delete_source=delete_source)
        return

    if not self.opt.recursive:
        message('omitting directory "%s".' % source)
        return

    prefix = S3URL(source).path
    objects = (obj for obj in self.s3walk(source) if not obj['is_dir'])
    for obj in objects:
        rel_key = os.path.relpath(S3URL(obj['name']).path, prefix)
        pool.copy(obj['name'], os.path.join(target, rel_key), delete_source=delete_source)
@log_calls
def cp_files(self, source, target, delete_source=False):
    '''Copy the S3 object(s) matched by `source` to the S3 location `target`.

    `source` may contain wildcards and is expanded with source_expand.
    - If `target` ends with PATH_SEP, every match is copied beneath it under
      its basename. Directories are handled (recursively, if enabled) by
      cp_single_file.
    - Otherwise at most one match is allowed, and it is copied to `target`.

    An empty expansion copies nothing. The original notes this happens only
    with ignore-empty-source. `delete_source` is passed through to
    cp_single_file.

    Raises:
      Failure: several matches but `target` lacks a trailing slash.
    '''
    pool = ThreadPool(ThreadUtil, self.opt)
    sources = self.source_expand(source)
    target_is_dir = target[-1] == PATH_SEP

    if target_is_dir:
        for src in sources:
            dest = os.path.join(target, self.get_basename(S3URL(src).path))
            self.cp_single_file(pool, src, dest, delete_source)
    else:
        if len(sources) > 1:
            raise Failure('Target "%s" is not a directory (with a trailing slash).' % target)
        if len(sources) == 1:
            self.cp_single_file(pool, sources[0], target, delete_source)

    pool.join()
@log_calls
def del_files(self, source):
    '''Delete every S3 object found by walking `source`.

    Directory entries reported by s3walk are skipped. All remaining object
    names go to the thread pool in a single batch_delete call.
    '''
    doomed = [obj['name'] for obj in self.s3walk(source) if not obj['is_dir']]
    pool = ThreadPool(ThreadUtil, self.opt)
    pool.batch_delete(doomed)
    pool.join()
@log_calls
def relative_dir_walk(self, dir):
    '''List the files under `dir` as paths relative to `dir`.

    `dir` may be an S3 URL (walked with s3walk) or a local directory (walked
    with local_walk). Directory entries are excluded in both cases. Stripping
    the base path lets two trees be compared name by name.
    '''
    if not S3URL.is_valid(dir):
        return [os.path.relpath(path, dir)
                for path in self.local_walk(dir)
                if not os.path.isdir(path)]

    base = S3URL(dir).path
    return [os.path.relpath(S3URL(entry['name']).path, base)
            for entry in self.s3walk(dir)
            if not entry['is_dir']]
@log_calls
def dsync_files(self, source, target):
'''Sync directory to directory.'''
src_s3_url = S3URL.is_valid(source)
dst_s3_url = S3URL.is_valid(target)
source_list = self.relative_dir_walk(source)
if len(source_list) == 0 or '.' in source_list:
raise Failure('Sync command need to sync directory to directory.')
sync_list = [(os.path.join(source, f), os.path.join(target, f)) for f in source_list]
pool = ThreadPool(ThreadUtil, self.opt)
if src_s3_url and not dst_s3_url:
for src, dest in sync_list:
pool.download(src, dest)
elif not src_s3_url and dst_s3_url:
for src, dest in sync_list:
pool.upload(src, dest)
elif src_s3_url and dst_s3_url:
for src, dest in sync_list: