diff --git a/tableaudocumentapi/workbook.py b/tableaudocumentapi/workbook.py index 889f746..0da1827 100644 --- a/tableaudocumentapi/workbook.py +++ b/tableaudocumentapi/workbook.py @@ -3,10 +3,57 @@ # Workbook - A class for writing Tableau workbook files # ############################################################################### +import contextlib import os +import shutil +import tempfile +import zipfile + import xml.etree.ElementTree as ET + from tableaudocumentapi import Datasource +########################################################################### +# +# Utility Functions +# +########################################################################### + + +@contextlib.contextmanager +def temporary_directory(*args, **kwargs): + d = tempfile.mkdtemp(*args, **kwargs) + try: + yield d + finally: + shutil.rmtree(d) + + +def find_twb_in_zip(zip): + for filename in zip.namelist(): + if os.path.splitext(filename)[-1].lower() == '.twb': + return filename + + +def get_twb_xml_from_twbx(filename): + with temporary_directory() as temp: + with zipfile.ZipFile(filename) as zf: + zf.extractall(temp) + twb_file = find_twb_in_zip(zf) + twb_xml = ET.parse(os.path.join(temp, twb_file)) + + return twb_xml + + +def build_twbx_file(twbx_contents, zip): + for root_dir, _, files in os.walk(twbx_contents): + relative_dir = os.path.relpath(root_dir, twbx_contents) + for f in files: + temp_file_full_path = os.path.join( + twbx_contents, relative_dir, f) + zipname = os.path.join(relative_dir, f) + zip.write(temp_file_full_path, arcname=zipname) + class Workbook(object): """ @@ -24,30 +71,18 @@ def __init__(self, filename): Constructor. """ - # We have a valid type of input file - if self._is_valid_file(filename): - # set our filename, open .twb, initialize things - self._filename = filename - self._workbookTree = ET.parse(filename) - self._workbookRoot = self._workbookTree.getroot() - - # prepare our datasource objects - self._datasources = self._prepare_datasources( - self._workbookRoot) # self.workbookRoot.find('datasources') - else: - print('Invalid file type. Must be .twb or .tds.') - raise Exception() - - @classmethod - def from_file(cls, filename): - "Initialize datasource from file (.tds)" - if self._is_valid_file(filename): - self._filename = filename - dsxml = ET.parse(filename).getroot() - return cls(dsxml) + self._filename = filename + + # Determine if this is a twb or twbx and get the xml root + if zipfile.is_zipfile(self._filename): + self._workbookTree = get_twb_xml_from_twbx(self._filename) else: - print('Invalid file type. Must be .twb or .tds.') - raise Exception() + self._workbookTree = ET.parse(self._filename) + + self._workbookRoot = self._workbookTree.getroot() + # prepare our datasource objects + self._datasources = self._prepare_datasources( + self._workbookRoot) # self.workbookRoot.find('datasources') ########### # datasources @@ -76,7 +111,12 @@ def save(self): """ # save the file - self._workbookTree.write(self._filename, encoding="utf-8", xml_declaration=True) + + if zipfile.is_zipfile(self._filename): + self._save_into_twbx(self._filename) + else: + self._workbookTree.write( + self._filename, encoding="utf-8", xml_declaration=True) def save_as(self, new_filename): """ @@ -90,7 +130,11 @@ def save_as(self, new_filename): """ - self._workbookTree.write(new_filename, encoding="utf-8", xml_declaration=True) + if zipfile.is_zipfile(self._filename): + self._save_into_twbx(new_filename) + else: + self._workbookTree.write( + new_filename, encoding="utf-8", xml_declaration=True) ########################################################################### # @@ -107,6 +151,29 @@ def _prepare_datasources(self, xmlRoot): return datasources + def _save_into_twbx(self, filename=None): + # Save reuses existing filename, 'save as' takes a new one + if filename is None: + filename = self._filename + + # Saving a twbx means extracting the contents into a temp folder, + # saving the changes over the twb in that folder, and then + # packaging it back up into a specifically formatted zip with the correct + # relative file paths + + # Extract to temp directory + with temporary_directory() as temp_path: + with zipfile.ZipFile(self._filename) as zf: + twb_file = find_twb_in_zip(zf) + zf.extractall(temp_path) + # Write the new version of the twb to the temp directory + self._workbookTree.write(os.path.join( + temp_path, twb_file), encoding="utf-8", xml_declaration=True) + + # Write the new twbx with the contents of the temp folder + with zipfile.ZipFile(filename, "w", compression=zipfile.ZIP_DEFLATED) as new_twbx: + build_twbx_file(temp_path, new_twbx) + @staticmethod def _is_valid_file(filename): fileExtension = os.path.splitext(filename)[-1].lower() diff --git a/test/assets/CONNECTION.xml b/test/assets/CONNECTION.xml new file mode 100644 index 0000000..392d112 --- /dev/null +++ b/test/assets/CONNECTION.xml @@ -0,0 +1 @@ + diff --git a/test/assets/TABLEAU_10_TDS.tds b/test/assets/TABLEAU_10_TDS.tds new file mode 100644 index 0000000..7a81784 --- /dev/null +++ b/test/assets/TABLEAU_10_TDS.tds @@ -0,0 +1 @@ + diff --git a/test/assets/TABLEAU_10_TWB.twb b/test/assets/TABLEAU_10_TWB.twb new file mode 100644 index 0000000..c116bdf --- /dev/null +++ b/test/assets/TABLEAU_10_TWB.twb @@ -0,0 +1 @@ + diff --git a/test/assets/TABLEAU_10_TWBX.twbx b/test/assets/TABLEAU_10_TWBX.twbx new file mode 100644 index 0000000..ef8f910 Binary files /dev/null and b/test/assets/TABLEAU_10_TWBX.twbx differ diff --git a/test/assets/TABLEAU_93_TDS.tds b/test/assets/TABLEAU_93_TDS.tds new file mode 100644 index 0000000..2afa3ea --- /dev/null +++ b/test/assets/TABLEAU_93_TDS.tds @@ -0,0 +1 @@ + diff --git a/test/assets/TABLEAU_93_TWB.twb b/test/assets/TABLEAU_93_TWB.twb new file mode 100644 index 0000000..cdb6484 --- /dev/null +++ b/test/assets/TABLEAU_93_TWB.twb @@ -0,0 +1 @@ + diff --git a/test/bvt.py b/test/bvt.py index f521465..aa4a247 100644 --- a/test/bvt.py +++ b/test/bvt.py @@ -1,23 +1,23 @@ -import unittest -import io import os +import unittest + import xml.etree.ElementTree as ET from tableaudocumentapi import Workbook, Datasource, Connection, ConnectionParser -# Disable the 120 line limit because of the embedded XML on these lines -# TODO: Move the XML into external files and load them when needed +TEST_DIR = os.path.dirname(__file__) + +TABLEAU_93_TWB = os.path.join(TEST_DIR, 'assets', 'TABLEAU_93_TWB.twb') -TABLEAU_93_WORKBOOK = '''''' # noqa +TABLEAU_93_TDS = os.path.join(TEST_DIR, 'assets', 'TABLEAU_93_TDS.tds') -TABLEAU_93_TDS = '''''' # noqa +TABLEAU_10_TDS = os.path.join(TEST_DIR, 'assets', 'TABLEAU_10_TDS.tds') -TABLEAU_10_TDS = '''''' # noqa +TABLEAU_10_TWB = os.path.join(TEST_DIR, 'assets', 'TABLEAU_10_TWB.twb') -TABLEAU_10_WORKBOOK = '''''' # noqa +TABLEAU_CONNECTION_XML = ET.parse(os.path.join(TEST_DIR, 'assets', 'CONNECTION.xml')).getroot() -TABLEAU_CONNECTION_XML = ET.fromstring( - '''''') # noqa +TABLEAU_10_TWBX = os.path.join(TEST_DIR, 'assets', 'TABLEAU_10_TWBX.twbx') class HelperMethodTests(unittest.TestCase): @@ -36,14 +36,14 @@ def test_is_valid_file_with_invalid_inputs(self): class ConnectionParserTests(unittest.TestCase): def test_can_extract_legacy_connection(self): - parser = ConnectionParser(ET.fromstring(TABLEAU_93_TDS), '9.2') + parser = ConnectionParser(ET.parse(TABLEAU_93_TDS), '9.2') connections = parser.get_connections() self.assertIsInstance(connections, list) self.assertIsInstance(connections[0], Connection) self.assertEqual(connections[0].dbname, 'TestV1') def test_can_extract_federated_connections(self): - parser = ConnectionParser(ET.fromstring(TABLEAU_10_TDS), '10.0') + parser = ConnectionParser(ET.parse(TABLEAU_10_TDS), '10.0') connections = parser.get_connections() self.assertIsInstance(connections, list) self.assertIsInstance(connections[0], Connection) @@ -76,9 +76,9 @@ def test_can_write_attributes_to_connection(self): class DatasourceModelTests(unittest.TestCase): def setUp(self): - self.tds_file = io.FileIO('test.tds', 'w') - self.tds_file.write(TABLEAU_93_TDS.encode('utf8')) - self.tds_file.seek(0) + with open(TABLEAU_93_TDS, 'rb') as in_file, open('test.tds', 'wb') as out_file: + out_file.write(in_file.read()) + self.tds_file = out_file def tearDown(self): self.tds_file.close() @@ -117,9 +117,9 @@ def test_save_has_xml_declaration(self): class DatasourceModelV10Tests(unittest.TestCase): def setUp(self): - self.tds_file = io.FileIO('test10.tds', 'w') - self.tds_file.write(TABLEAU_10_TDS.encode('utf8')) - self.tds_file.seek(0) + with open(TABLEAU_10_TDS, 'rb') as in_file, open('test.twb', 'wb') as out_file: + out_file.write(in_file.read()) + self.tds_file = out_file def tearDown(self): self.tds_file.close() @@ -147,9 +147,9 @@ def test_can_save_tds(self): class WorkbookModelTests(unittest.TestCase): def setUp(self): - self.workbook_file = io.FileIO('test.twb', 'w') - self.workbook_file.write(TABLEAU_93_WORKBOOK.encode('utf8')) - self.workbook_file.seek(0) + with open(TABLEAU_93_TWB, 'rb') as in_file, open('test.twb', 'wb') as out_file: + out_file.write(in_file.read()) + self.workbook_file = out_file def tearDown(self): self.workbook_file.close() @@ -175,9 +175,9 @@ def test_can_update_datasource_connection_and_save(self): class WorkbookModelV10Tests(unittest.TestCase): def setUp(self): - self.workbook_file = io.FileIO('testv10.twb', 'w') - self.workbook_file.write(TABLEAU_10_WORKBOOK.encode('utf8')) - self.workbook_file.seek(0) + with open(TABLEAU_10_TWB, 'rb') as in_file, open('test.twb', 'wb') as out_file: + out_file.write(in_file.read()) + self.workbook_file = out_file def tearDown(self): self.workbook_file.close() @@ -213,5 +213,43 @@ def test_save_has_xml_declaration(self): self.assertEqual( first_line, "") + +class WorkbookModelV10TWBXTests(unittest.TestCase): + + def setUp(self): + with open(TABLEAU_10_TWBX, 'rb') as in_file, open('test.twbx', 'wb') as out_file: + out_file.write(in_file.read()) + self.workbook_file = out_file + + def tearDown(self): + self.workbook_file.close() + os.unlink(self.workbook_file.name) + + def test_can_open_twbx(self): + wb = Workbook(self.workbook_file.name) + self.assertTrue(wb.datasources) + self.assertTrue(wb.datasources[0].connections) + + def test_can_open_twbx_and_save_changes(self): + original_wb = Workbook(self.workbook_file.name) + original_wb.datasources[0].connections[0].server = 'newdb.test.tsi.lan' + original_wb.save() + + new_wb = Workbook(self.workbook_file.name) + self.assertEqual(new_wb.datasources[0].connections[ + 0].server, 'newdb.test.tsi.lan') + + def test_can_open_twbx_and_save_as_changes(self): + new_twbx_filename = self.workbook_file.name + "_TEST_SAVE_AS" + original_wb = Workbook(self.workbook_file.name) + original_wb.datasources[0].connections[0].server = 'newdb.test.tsi.lan' + original_wb.save_as(new_twbx_filename) + + new_wb = Workbook(new_twbx_filename) + self.assertEqual(new_wb.datasources[0].connections[ + 0].server, 'newdb.test.tsi.lan') + + os.unlink(new_twbx_filename) + if __name__ == '__main__': unittest.main() pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy