diff --git a/.nojekyll b/.nojekyll new file mode 100644 index 0000000000..e69de29bb2 diff --git a/404.html b/404.html new file mode 100644 index 0000000000..e5e9cfac53 --- /dev/null +++ b/404.html @@ -0,0 +1,1353 @@ + + + + + + + + + + + + + + + + + + + + + + Amundsen + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + +
+
+ + + +
+
+
+ + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ +

404 - Not found

+ +
+
+ + +
+ +
+ + + +
+
+
+
+ + + + + + + + + \ No newline at end of file diff --git a/CONTRIBUTING/index.html b/CONTRIBUTING/index.html new file mode 100644 index 0000000000..e145d1210c --- /dev/null +++ b/CONTRIBUTING/index.html @@ -0,0 +1,1534 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + Contributing Guide - Amundsen + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + +
+
+ + + +
+
+
+ + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + +  + + + + + + +

Contributing Guide

+

First-time Contributors

+

If this is your first contribution to open source, you can follow this tutorial or check this video series to learn about the contribution workflow with GitHub.

+

We always have tickets labeled ‘good first issue’ and ‘help wanted’. These are a great starting point if you want to contribute. Don’t hesitate to ask questions about an issue in the Amundsen Slack channel if you are unsure how to approach it.

+

Reporting an Issue

+

The easiest way to contribute to Amundsen is by creating issues. First, search the issues section of the Amundsen repository in case a similar bug or feature request already exists. If you don’t find it, submit your bug, question, proposal, or feature request. New issues will remain closed until sufficient interest, e.g. 👍 reactions, has been shown by the community.

+

Creating Pull Requests

+

Before sending a pull request with significant changes, please use the issue tracker to discuss the potential improvements you want to make. This helps us point you to an existing solution, a workaround, or an RFC (a request-for-comments item in our RFC repo).

+

Requesting a Feature

+

We have created a community roadmap where you can vote on plans for upcoming releases. However, we are open to hearing your ideas for new features!

+

For that, you can create an issue and select the “Feature Proposal” template. Fill in as much information as possible, and if you can, add responses to the following questions:

+
    +
  • Will we need to add a new model or change any existing model?
  • +
  • What would the migration plan look like? Will it be backwards-compatible?
  • +
  • Which alternatives did you consider?
  • +
+

Setup

+

To start contributing to Amundsen, you need to set up your machine to develop with the project. For that, we have prepared a developer guide that walks you through setting up your environment to develop locally with Amundsen.

+

Next Steps

+

Once you have your environment set and ready to go, you can check our documentation and the project’s community roadmap to see what’s coming.

+ + + + + + + + + +
+
+ + +
+ +
+ + + +
+
+
+
+ + + + + + + + + \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000..a1c70dc855 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2018 Lyft, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/architecture/index.html b/architecture/index.html new file mode 100644 index 0000000000..ee59f120f8 --- /dev/null +++ b/architecture/index.html @@ -0,0 +1,1504 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + Architecture - Amundsen + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + +
+ + +
+ +
+ + + + + + +
+
+ + + +
+
+
+ + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + +  + + + + + + +

Architecture

+

The following diagram shows the overall architecture for Amundsen. +

+

Frontend

+

The frontend service serves as a web UI portal for user interaction. +It is a Flask-based web app whose presentation layer is built with React with Redux, Bootstrap, Webpack, and Babel.

+

Search

+

The search service proxy leverages Elasticsearch’s search functionality (or Apache Atlas’s search API, if that’s the backend you picked) and +provides a RESTful API to serve search requests from the frontend service. This API is documented and explorable live through OpenAPI, aka “Swagger”. +Currently, only table resources are indexed and searchable. +The search index is built with the databuilder Elasticsearch publisher.

+
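As a rough illustration of how a client might call this API, here is a minimal sketch. The base URL, route, and query parameters below are assumptions for this example, not a documented contract; consult the service’s Swagger/OpenAPI page for the real one.

```python
# Minimal sketch of querying the search service's REST API.
# Base URL and route are assumed for illustration only.
import requests

SEARCH_SERVICE_BASE = "http://localhost:5001"  # assumed local deployment


def search_tables(query_term: str, page_index: int = 0) -> dict:
    """Search indexed table resources by a free-text term."""
    response = requests.get(
        f"{SEARCH_SERVICE_BASE}/search",  # assumed endpoint path
        params={"query_term": query_term, "page_index": page_index},
    )
    response.raise_for_status()
    return response.json()


if __name__ == "__main__":
    print(search_tables("customer orders"))
```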

Metadata

+

The metadata service currently uses a Neo4j proxy to interact with the Neo4j graph database and serves metadata to the frontend service. +The metadata is represented as a graph model: + +The above diagram shows how metadata is modeled in Amundsen.

+
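To make the graph model concrete, the sketch below reads table metadata directly from Neo4j with the official Python driver. The node labels, relationship name, and property names are assumptions for illustration; the actual Amundsen graph schema may differ.

```python
# Sketch of reading table metadata from the Neo4j graph.
# Labels/relationships below are illustrative assumptions.
from neo4j import GraphDatabase

driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "test"))


def list_table_columns(table_key: str) -> list:
    """Return the column names attached to a table node."""
    query = (
        "MATCH (t:Table {key: $key})-[:COLUMN]->(c:Column) "
        "RETURN c.name AS name ORDER BY name"
    )
    with driver.session() as session:
        return [record["name"] for record in session.run(query, key=table_key)]
```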

Databuilder

+

Amundsen provides a data ingestion library for building the metadata. At Lyft, we build the metadata once a day +using an Airflow DAG (examples).

+
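For orientation, a daily ingestion DAG can be as small as the sketch below. The `run_table_ingestion` callable is hypothetical (it would wire up your databuilder extractor, loader, and publisher), and the operator import path varies with your Airflow version; see the linked databuilder examples for real job definitions.

```python
# Minimal sketch of a daily Airflow DAG that triggers a databuilder job.
# `run_table_ingestion` is a hypothetical placeholder callable.
from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator  # Airflow 2.x path


def run_table_ingestion() -> None:
    # Build and launch a databuilder job here (extractor -> loader -> publisher).
    ...


with DAG(
    dag_id="amundsen_metadata_ingestion",
    start_date=datetime(2021, 1, 1),
    schedule_interval="@daily",
    catchup=False,
) as dag:
    ingest = PythonOperator(
        task_id="ingest_table_metadata",
        python_callable=run_table_ingestion,
    )
```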

In addition to “real use”, the databuilder is also employed as a handy tool to ingest the “pre-cooked” demo data used in the Quickstart guide. This gives you a very small sample of data to explore, so many of Amundsen’s features are lit up without you even having to set up connections to real databases.

+ + + + + + + + + +
+
+ + +
+ +
+ + + +
+
+
+
+ + + + + + + + + \ No newline at end of file diff --git a/assets/images/favicon.png b/assets/images/favicon.png new file mode 100644 index 0000000000..1cf13b9f9d Binary files /dev/null and b/assets/images/favicon.png differ diff --git a/assets/javascripts/bundle.51d95adb.min.js b/assets/javascripts/bundle.51d95adb.min.js new file mode 100644 index 0000000000..b20ec6835b --- /dev/null +++ b/assets/javascripts/bundle.51d95adb.min.js @@ -0,0 +1,29 @@ +"use strict";(()=>{var Hi=Object.create;var xr=Object.defineProperty;var Pi=Object.getOwnPropertyDescriptor;var $i=Object.getOwnPropertyNames,kt=Object.getOwnPropertySymbols,Ii=Object.getPrototypeOf,Er=Object.prototype.hasOwnProperty,an=Object.prototype.propertyIsEnumerable;var on=(e,t,r)=>t in e?xr(e,t,{enumerable:!0,configurable:!0,writable:!0,value:r}):e[t]=r,P=(e,t)=>{for(var r in t||(t={}))Er.call(t,r)&&on(e,r,t[r]);if(kt)for(var r of kt(t))an.call(t,r)&&on(e,r,t[r]);return e};var sn=(e,t)=>{var r={};for(var n in e)Er.call(e,n)&&t.indexOf(n)<0&&(r[n]=e[n]);if(e!=null&&kt)for(var n of kt(e))t.indexOf(n)<0&&an.call(e,n)&&(r[n]=e[n]);return r};var Ht=(e,t)=>()=>(t||e((t={exports:{}}).exports,t),t.exports);var Fi=(e,t,r,n)=>{if(t&&typeof t=="object"||typeof t=="function")for(let o of $i(t))!Er.call(e,o)&&o!==r&&xr(e,o,{get:()=>t[o],enumerable:!(n=Pi(t,o))||n.enumerable});return e};var yt=(e,t,r)=>(r=e!=null?Hi(Ii(e)):{},Fi(t||!e||!e.__esModule?xr(r,"default",{value:e,enumerable:!0}):r,e));var fn=Ht((wr,cn)=>{(function(e,t){typeof wr=="object"&&typeof cn!="undefined"?t():typeof define=="function"&&define.amd?define(t):t()})(wr,function(){"use strict";function e(r){var n=!0,o=!1,i=null,a={text:!0,search:!0,url:!0,tel:!0,email:!0,password:!0,number:!0,date:!0,month:!0,week:!0,time:!0,datetime:!0,"datetime-local":!0};function s(T){return!!(T&&T!==document&&T.nodeName!=="HTML"&&T.nodeName!=="BODY"&&"classList"in T&&"contains"in T.classList)}function f(T){var Ke=T.type,We=T.tagName;return!!(We==="INPUT"&&a[Ke]&&!T.readOnly||We==="TEXTAREA"&&!T.readOnly||T.isContentEditable)}function c(T){T.classList.contains("focus-visible")||(T.classList.add("focus-visible"),T.setAttribute("data-focus-visible-added",""))}function u(T){T.hasAttribute("data-focus-visible-added")&&(T.classList.remove("focus-visible"),T.removeAttribute("data-focus-visible-added"))}function p(T){T.metaKey||T.altKey||T.ctrlKey||(s(r.activeElement)&&c(r.activeElement),n=!0)}function m(T){n=!1}function d(T){s(T.target)&&(n||f(T.target))&&c(T.target)}function h(T){s(T.target)&&(T.target.classList.contains("focus-visible")||T.target.hasAttribute("data-focus-visible-added"))&&(o=!0,window.clearTimeout(i),i=window.setTimeout(function(){o=!1},100),u(T.target))}function v(T){document.visibilityState==="hidden"&&(o&&(n=!0),B())}function B(){document.addEventListener("mousemove",z),document.addEventListener("mousedown",z),document.addEventListener("mouseup",z),document.addEventListener("pointermove",z),document.addEventListener("pointerdown",z),document.addEventListener("pointerup",z),document.addEventListener("touchmove",z),document.addEventListener("touchstart",z),document.addEventListener("touchend",z)}function 
re(){document.removeEventListener("mousemove",z),document.removeEventListener("mousedown",z),document.removeEventListener("mouseup",z),document.removeEventListener("pointermove",z),document.removeEventListener("pointerdown",z),document.removeEventListener("pointerup",z),document.removeEventListener("touchmove",z),document.removeEventListener("touchstart",z),document.removeEventListener("touchend",z)}function z(T){T.target.nodeName&&T.target.nodeName.toLowerCase()==="html"||(n=!1,re())}document.addEventListener("keydown",p,!0),document.addEventListener("mousedown",m,!0),document.addEventListener("pointerdown",m,!0),document.addEventListener("touchstart",m,!0),document.addEventListener("visibilitychange",v,!0),B(),r.addEventListener("focus",d,!0),r.addEventListener("blur",h,!0),r.nodeType===Node.DOCUMENT_FRAGMENT_NODE&&r.host?r.host.setAttribute("data-js-focus-visible",""):r.nodeType===Node.DOCUMENT_NODE&&(document.documentElement.classList.add("js-focus-visible"),document.documentElement.setAttribute("data-js-focus-visible",""))}if(typeof window!="undefined"&&typeof document!="undefined"){window.applyFocusVisiblePolyfill=e;var t;try{t=new CustomEvent("focus-visible-polyfill-ready")}catch(r){t=document.createEvent("CustomEvent"),t.initCustomEvent("focus-visible-polyfill-ready",!1,!1,{})}window.dispatchEvent(t)}typeof document!="undefined"&&e(document)})});var un=Ht(Sr=>{(function(e){var t=function(){try{return!!Symbol.iterator}catch(c){return!1}},r=t(),n=function(c){var u={next:function(){var p=c.shift();return{done:p===void 0,value:p}}};return r&&(u[Symbol.iterator]=function(){return u}),u},o=function(c){return encodeURIComponent(c).replace(/%20/g,"+")},i=function(c){return decodeURIComponent(String(c).replace(/\+/g," "))},a=function(){var c=function(p){Object.defineProperty(this,"_entries",{writable:!0,value:{}});var m=typeof p;if(m!=="undefined")if(m==="string")p!==""&&this._fromString(p);else if(p instanceof c){var d=this;p.forEach(function(re,z){d.append(z,re)})}else if(p!==null&&m==="object")if(Object.prototype.toString.call(p)==="[object Array]")for(var h=0;hd[0]?1:0}),c._entries&&(c._entries={});for(var p=0;p1?i(d[1]):"")}})})(typeof global!="undefined"?global:typeof window!="undefined"?window:typeof self!="undefined"?self:Sr);(function(e){var t=function(){try{var o=new e.URL("b","http://a");return o.pathname="c d",o.href==="http://a/c%20d"&&o.searchParams}catch(i){return!1}},r=function(){var o=e.URL,i=function(f,c){typeof f!="string"&&(f=String(f)),c&&typeof c!="string"&&(c=String(c));var u=document,p;if(c&&(e.location===void 0||c!==e.location.href)){c=c.toLowerCase(),u=document.implementation.createHTMLDocument(""),p=u.createElement("base"),p.href=c,u.head.appendChild(p);try{if(p.href.indexOf(c)!==0)throw new Error(p.href)}catch(T){throw new Error("URL unable to set base "+c+" due to "+T)}}var m=u.createElement("a");m.href=f,p&&(u.body.appendChild(m),m.href=m.href);var d=u.createElement("input");if(d.type="url",d.value=f,m.protocol===":"||!/:/.test(m.href)||!d.checkValidity()&&!c)throw new TypeError("Invalid URL");Object.defineProperty(this,"_anchorElement",{value:m});var h=new e.URLSearchParams(this.search),v=!0,B=!0,re=this;["append","delete","set"].forEach(function(T){var Ke=h[T];h[T]=function(){Ke.apply(h,arguments),v&&(B=!1,re.search=h.toString(),B=!0)}}),Object.defineProperty(this,"searchParams",{value:h,enumerable:!0});var z=void 
0;Object.defineProperty(this,"_updateSearchParams",{enumerable:!1,configurable:!1,writable:!1,value:function(){this.search!==z&&(z=this.search,B&&(v=!1,this.searchParams._fromString(this.search),v=!0))}})},a=i.prototype,s=function(f){Object.defineProperty(a,f,{get:function(){return this._anchorElement[f]},set:function(c){this._anchorElement[f]=c},enumerable:!0})};["hash","host","hostname","port","protocol"].forEach(function(f){s(f)}),Object.defineProperty(a,"search",{get:function(){return this._anchorElement.search},set:function(f){this._anchorElement.search=f,this._updateSearchParams()},enumerable:!0}),Object.defineProperties(a,{toString:{get:function(){var f=this;return function(){return f.href}}},href:{get:function(){return this._anchorElement.href.replace(/\?$/,"")},set:function(f){this._anchorElement.href=f,this._updateSearchParams()},enumerable:!0},pathname:{get:function(){return this._anchorElement.pathname.replace(/(^\/?)/,"/")},set:function(f){this._anchorElement.pathname=f},enumerable:!0},origin:{get:function(){var f={"http:":80,"https:":443,"ftp:":21}[this._anchorElement.protocol],c=this._anchorElement.port!=f&&this._anchorElement.port!=="";return this._anchorElement.protocol+"//"+this._anchorElement.hostname+(c?":"+this._anchorElement.port:"")},enumerable:!0},password:{get:function(){return""},set:function(f){},enumerable:!0},username:{get:function(){return""},set:function(f){},enumerable:!0}}),i.createObjectURL=function(f){return o.createObjectURL.apply(o,arguments)},i.revokeObjectURL=function(f){return o.revokeObjectURL.apply(o,arguments)},e.URL=i};if(t()||r(),e.location!==void 0&&!("origin"in e.location)){var n=function(){return e.location.protocol+"//"+e.location.hostname+(e.location.port?":"+e.location.port:"")};try{Object.defineProperty(e.location,"origin",{get:n,enumerable:!0})}catch(o){setInterval(function(){e.location.origin=n()},100)}}})(typeof global!="undefined"?global:typeof window!="undefined"?window:typeof self!="undefined"?self:Sr)});var Qr=Ht((Lt,Kr)=>{/*! 
+ * clipboard.js v2.0.11 + * https://clipboardjs.com/ + * + * Licensed MIT © Zeno Rocha + */(function(t,r){typeof Lt=="object"&&typeof Kr=="object"?Kr.exports=r():typeof define=="function"&&define.amd?define([],r):typeof Lt=="object"?Lt.ClipboardJS=r():t.ClipboardJS=r()})(Lt,function(){return function(){var e={686:function(n,o,i){"use strict";i.d(o,{default:function(){return ki}});var a=i(279),s=i.n(a),f=i(370),c=i.n(f),u=i(817),p=i.n(u);function m(j){try{return document.execCommand(j)}catch(O){return!1}}var d=function(O){var w=p()(O);return m("cut"),w},h=d;function v(j){var O=document.documentElement.getAttribute("dir")==="rtl",w=document.createElement("textarea");w.style.fontSize="12pt",w.style.border="0",w.style.padding="0",w.style.margin="0",w.style.position="absolute",w.style[O?"right":"left"]="-9999px";var k=window.pageYOffset||document.documentElement.scrollTop;return w.style.top="".concat(k,"px"),w.setAttribute("readonly",""),w.value=j,w}var B=function(O,w){var k=v(O);w.container.appendChild(k);var F=p()(k);return m("copy"),k.remove(),F},re=function(O){var w=arguments.length>1&&arguments[1]!==void 0?arguments[1]:{container:document.body},k="";return typeof O=="string"?k=B(O,w):O instanceof HTMLInputElement&&!["text","search","url","tel","password"].includes(O==null?void 0:O.type)?k=B(O.value,w):(k=p()(O),m("copy")),k},z=re;function T(j){return typeof Symbol=="function"&&typeof Symbol.iterator=="symbol"?T=function(w){return typeof w}:T=function(w){return w&&typeof Symbol=="function"&&w.constructor===Symbol&&w!==Symbol.prototype?"symbol":typeof w},T(j)}var Ke=function(){var O=arguments.length>0&&arguments[0]!==void 0?arguments[0]:{},w=O.action,k=w===void 0?"copy":w,F=O.container,q=O.target,Le=O.text;if(k!=="copy"&&k!=="cut")throw new Error('Invalid "action" value, use either "copy" or "cut"');if(q!==void 0)if(q&&T(q)==="object"&&q.nodeType===1){if(k==="copy"&&q.hasAttribute("disabled"))throw new Error('Invalid "target" attribute. Please use "readonly" instead of "disabled" attribute');if(k==="cut"&&(q.hasAttribute("readonly")||q.hasAttribute("disabled")))throw new Error(`Invalid "target" attribute. 
You can't cut text from elements with "readonly" or "disabled" attributes`)}else throw new Error('Invalid "target" value, use a valid Element');if(Le)return z(Le,{container:F});if(q)return k==="cut"?h(q):z(q,{container:F})},We=Ke;function Ie(j){return typeof Symbol=="function"&&typeof Symbol.iterator=="symbol"?Ie=function(w){return typeof w}:Ie=function(w){return w&&typeof Symbol=="function"&&w.constructor===Symbol&&w!==Symbol.prototype?"symbol":typeof w},Ie(j)}function Ti(j,O){if(!(j instanceof O))throw new TypeError("Cannot call a class as a function")}function nn(j,O){for(var w=0;w0&&arguments[0]!==void 0?arguments[0]:{};this.action=typeof F.action=="function"?F.action:this.defaultAction,this.target=typeof F.target=="function"?F.target:this.defaultTarget,this.text=typeof F.text=="function"?F.text:this.defaultText,this.container=Ie(F.container)==="object"?F.container:document.body}},{key:"listenClick",value:function(F){var q=this;this.listener=c()(F,"click",function(Le){return q.onClick(Le)})}},{key:"onClick",value:function(F){var q=F.delegateTarget||F.currentTarget,Le=this.action(q)||"copy",Rt=We({action:Le,container:this.container,target:this.target(q),text:this.text(q)});this.emit(Rt?"success":"error",{action:Le,text:Rt,trigger:q,clearSelection:function(){q&&q.focus(),window.getSelection().removeAllRanges()}})}},{key:"defaultAction",value:function(F){return yr("action",F)}},{key:"defaultTarget",value:function(F){var q=yr("target",F);if(q)return document.querySelector(q)}},{key:"defaultText",value:function(F){return yr("text",F)}},{key:"destroy",value:function(){this.listener.destroy()}}],[{key:"copy",value:function(F){var q=arguments.length>1&&arguments[1]!==void 0?arguments[1]:{container:document.body};return z(F,q)}},{key:"cut",value:function(F){return h(F)}},{key:"isSupported",value:function(){var F=arguments.length>0&&arguments[0]!==void 0?arguments[0]:["copy","cut"],q=typeof F=="string"?[F]:F,Le=!!document.queryCommandSupported;return q.forEach(function(Rt){Le=Le&&!!document.queryCommandSupported(Rt)}),Le}}]),w}(s()),ki=Ri},828:function(n){var o=9;if(typeof Element!="undefined"&&!Element.prototype.matches){var i=Element.prototype;i.matches=i.matchesSelector||i.mozMatchesSelector||i.msMatchesSelector||i.oMatchesSelector||i.webkitMatchesSelector}function a(s,f){for(;s&&s.nodeType!==o;){if(typeof s.matches=="function"&&s.matches(f))return s;s=s.parentNode}}n.exports=a},438:function(n,o,i){var a=i(828);function s(u,p,m,d,h){var v=c.apply(this,arguments);return u.addEventListener(m,v,h),{destroy:function(){u.removeEventListener(m,v,h)}}}function f(u,p,m,d,h){return typeof u.addEventListener=="function"?s.apply(null,arguments):typeof m=="function"?s.bind(null,document).apply(null,arguments):(typeof u=="string"&&(u=document.querySelectorAll(u)),Array.prototype.map.call(u,function(v){return s(v,p,m,d,h)}))}function c(u,p,m,d){return function(h){h.delegateTarget=a(h.target,p),h.delegateTarget&&d.call(u,h)}}n.exports=f},879:function(n,o){o.node=function(i){return i!==void 0&&i instanceof HTMLElement&&i.nodeType===1},o.nodeList=function(i){var a=Object.prototype.toString.call(i);return i!==void 0&&(a==="[object NodeList]"||a==="[object HTMLCollection]")&&"length"in i&&(i.length===0||o.node(i[0]))},o.string=function(i){return typeof i=="string"||i instanceof String},o.fn=function(i){var a=Object.prototype.toString.call(i);return a==="[object Function]"}},370:function(n,o,i){var a=i(879),s=i(438);function f(m,d,h){if(!m&&!d&&!h)throw new Error("Missing required 
arguments");if(!a.string(d))throw new TypeError("Second argument must be a String");if(!a.fn(h))throw new TypeError("Third argument must be a Function");if(a.node(m))return c(m,d,h);if(a.nodeList(m))return u(m,d,h);if(a.string(m))return p(m,d,h);throw new TypeError("First argument must be a String, HTMLElement, HTMLCollection, or NodeList")}function c(m,d,h){return m.addEventListener(d,h),{destroy:function(){m.removeEventListener(d,h)}}}function u(m,d,h){return Array.prototype.forEach.call(m,function(v){v.addEventListener(d,h)}),{destroy:function(){Array.prototype.forEach.call(m,function(v){v.removeEventListener(d,h)})}}}function p(m,d,h){return s(document.body,m,d,h)}n.exports=f},817:function(n){function o(i){var a;if(i.nodeName==="SELECT")i.focus(),a=i.value;else if(i.nodeName==="INPUT"||i.nodeName==="TEXTAREA"){var s=i.hasAttribute("readonly");s||i.setAttribute("readonly",""),i.select(),i.setSelectionRange(0,i.value.length),s||i.removeAttribute("readonly"),a=i.value}else{i.hasAttribute("contenteditable")&&i.focus();var f=window.getSelection(),c=document.createRange();c.selectNodeContents(i),f.removeAllRanges(),f.addRange(c),a=f.toString()}return a}n.exports=o},279:function(n){function o(){}o.prototype={on:function(i,a,s){var f=this.e||(this.e={});return(f[i]||(f[i]=[])).push({fn:a,ctx:s}),this},once:function(i,a,s){var f=this;function c(){f.off(i,c),a.apply(s,arguments)}return c._=a,this.on(i,c,s)},emit:function(i){var a=[].slice.call(arguments,1),s=((this.e||(this.e={}))[i]||[]).slice(),f=0,c=s.length;for(f;f{"use strict";/*! + * escape-html + * Copyright(c) 2012-2013 TJ Holowaychuk + * Copyright(c) 2015 Andreas Lubbe + * Copyright(c) 2015 Tiancheng "Timothy" Gu + * MIT Licensed + */var is=/["'&<>]/;Jo.exports=as;function as(e){var t=""+e,r=is.exec(t);if(!r)return t;var n,o="",i=0,a=0;for(i=r.index;i0&&i[i.length-1])&&(c[0]===6||c[0]===2)){r=0;continue}if(c[0]===3&&(!i||c[1]>i[0]&&c[1]=e.length&&(e=void 0),{value:e&&e[n++],done:!e}}};throw new TypeError(t?"Object is not iterable.":"Symbol.iterator is not defined.")}function W(e,t){var r=typeof Symbol=="function"&&e[Symbol.iterator];if(!r)return e;var n=r.call(e),o,i=[],a;try{for(;(t===void 0||t-- >0)&&!(o=n.next()).done;)i.push(o.value)}catch(s){a={error:s}}finally{try{o&&!o.done&&(r=n.return)&&r.call(n)}finally{if(a)throw a.error}}return i}function D(e,t,r){if(r||arguments.length===2)for(var n=0,o=t.length,i;n1||s(m,d)})})}function s(m,d){try{f(n[m](d))}catch(h){p(i[0][3],h)}}function f(m){m.value instanceof Xe?Promise.resolve(m.value.v).then(c,u):p(i[0][2],m)}function c(m){s("next",m)}function u(m){s("throw",m)}function p(m,d){m(d),i.shift(),i.length&&s(i[0][0],i[0][1])}}function mn(e){if(!Symbol.asyncIterator)throw new TypeError("Symbol.asyncIterator is not defined.");var t=e[Symbol.asyncIterator],r;return t?t.call(e):(e=typeof xe=="function"?xe(e):e[Symbol.iterator](),r={},n("next"),n("throw"),n("return"),r[Symbol.asyncIterator]=function(){return this},r);function n(i){r[i]=e[i]&&function(a){return new Promise(function(s,f){a=e[i](a),o(s,f,a.done,a.value)})}}function o(i,a,s,f){Promise.resolve(f).then(function(c){i({value:c,done:s})},a)}}function A(e){return typeof e=="function"}function at(e){var t=function(n){Error.call(n),n.stack=new Error().stack},r=e(t);return r.prototype=Object.create(Error.prototype),r.prototype.constructor=r,r}var $t=at(function(e){return function(r){e(this),this.message=r?r.length+` errors occurred during unsubscription: +`+r.map(function(n,o){return o+1+") "+n.toString()}).join(` + 
`):"",this.name="UnsubscriptionError",this.errors=r}});function De(e,t){if(e){var r=e.indexOf(t);0<=r&&e.splice(r,1)}}var Fe=function(){function e(t){this.initialTeardown=t,this.closed=!1,this._parentage=null,this._finalizers=null}return e.prototype.unsubscribe=function(){var t,r,n,o,i;if(!this.closed){this.closed=!0;var a=this._parentage;if(a)if(this._parentage=null,Array.isArray(a))try{for(var s=xe(a),f=s.next();!f.done;f=s.next()){var c=f.value;c.remove(this)}}catch(v){t={error:v}}finally{try{f&&!f.done&&(r=s.return)&&r.call(s)}finally{if(t)throw t.error}}else a.remove(this);var u=this.initialTeardown;if(A(u))try{u()}catch(v){i=v instanceof $t?v.errors:[v]}var p=this._finalizers;if(p){this._finalizers=null;try{for(var m=xe(p),d=m.next();!d.done;d=m.next()){var h=d.value;try{dn(h)}catch(v){i=i!=null?i:[],v instanceof $t?i=D(D([],W(i)),W(v.errors)):i.push(v)}}}catch(v){n={error:v}}finally{try{d&&!d.done&&(o=m.return)&&o.call(m)}finally{if(n)throw n.error}}}if(i)throw new $t(i)}},e.prototype.add=function(t){var r;if(t&&t!==this)if(this.closed)dn(t);else{if(t instanceof e){if(t.closed||t._hasParent(this))return;t._addParent(this)}(this._finalizers=(r=this._finalizers)!==null&&r!==void 0?r:[]).push(t)}},e.prototype._hasParent=function(t){var r=this._parentage;return r===t||Array.isArray(r)&&r.includes(t)},e.prototype._addParent=function(t){var r=this._parentage;this._parentage=Array.isArray(r)?(r.push(t),r):r?[r,t]:t},e.prototype._removeParent=function(t){var r=this._parentage;r===t?this._parentage=null:Array.isArray(r)&&De(r,t)},e.prototype.remove=function(t){var r=this._finalizers;r&&De(r,t),t instanceof e&&t._removeParent(this)},e.EMPTY=function(){var t=new e;return t.closed=!0,t}(),e}();var Or=Fe.EMPTY;function It(e){return e instanceof Fe||e&&"closed"in e&&A(e.remove)&&A(e.add)&&A(e.unsubscribe)}function dn(e){A(e)?e():e.unsubscribe()}var Ae={onUnhandledError:null,onStoppedNotification:null,Promise:void 0,useDeprecatedSynchronousErrorHandling:!1,useDeprecatedNextContext:!1};var st={setTimeout:function(e,t){for(var r=[],n=2;n0},enumerable:!1,configurable:!0}),t.prototype._trySubscribe=function(r){return this._throwIfClosed(),e.prototype._trySubscribe.call(this,r)},t.prototype._subscribe=function(r){return this._throwIfClosed(),this._checkFinalizedStatuses(r),this._innerSubscribe(r)},t.prototype._innerSubscribe=function(r){var n=this,o=this,i=o.hasError,a=o.isStopped,s=o.observers;return i||a?Or:(this.currentObservers=null,s.push(r),new Fe(function(){n.currentObservers=null,De(s,r)}))},t.prototype._checkFinalizedStatuses=function(r){var n=this,o=n.hasError,i=n.thrownError,a=n.isStopped;o?r.error(i):a&&r.complete()},t.prototype.asObservable=function(){var r=new U;return r.source=this,r},t.create=function(r,n){return new wn(r,n)},t}(U);var wn=function(e){ne(t,e);function t(r,n){var o=e.call(this)||this;return o.destination=r,o.source=n,o}return t.prototype.next=function(r){var n,o;(o=(n=this.destination)===null||n===void 0?void 0:n.next)===null||o===void 0||o.call(n,r)},t.prototype.error=function(r){var n,o;(o=(n=this.destination)===null||n===void 0?void 0:n.error)===null||o===void 0||o.call(n,r)},t.prototype.complete=function(){var r,n;(n=(r=this.destination)===null||r===void 0?void 0:r.complete)===null||n===void 0||n.call(r)},t.prototype._subscribe=function(r){var n,o;return(o=(n=this.source)===null||n===void 0?void 0:n.subscribe(r))!==null&&o!==void 0?o:Or},t}(E);var Et={now:function(){return(Et.delegate||Date).now()},delegate:void 0};var wt=function(e){ne(t,e);function t(r,n,o){r===void 
0&&(r=1/0),n===void 0&&(n=1/0),o===void 0&&(o=Et);var i=e.call(this)||this;return i._bufferSize=r,i._windowTime=n,i._timestampProvider=o,i._buffer=[],i._infiniteTimeWindow=!0,i._infiniteTimeWindow=n===1/0,i._bufferSize=Math.max(1,r),i._windowTime=Math.max(1,n),i}return t.prototype.next=function(r){var n=this,o=n.isStopped,i=n._buffer,a=n._infiniteTimeWindow,s=n._timestampProvider,f=n._windowTime;o||(i.push(r),!a&&i.push(s.now()+f)),this._trimBuffer(),e.prototype.next.call(this,r)},t.prototype._subscribe=function(r){this._throwIfClosed(),this._trimBuffer();for(var n=this._innerSubscribe(r),o=this,i=o._infiniteTimeWindow,a=o._buffer,s=a.slice(),f=0;f0?e.prototype.requestAsyncId.call(this,r,n,o):(r.actions.push(this),r._scheduled||(r._scheduled=ut.requestAnimationFrame(function(){return r.flush(void 0)})))},t.prototype.recycleAsyncId=function(r,n,o){var i;if(o===void 0&&(o=0),o!=null?o>0:this.delay>0)return e.prototype.recycleAsyncId.call(this,r,n,o);var a=r.actions;n!=null&&((i=a[a.length-1])===null||i===void 0?void 0:i.id)!==n&&(ut.cancelAnimationFrame(n),r._scheduled=void 0)},t}(Ut);var On=function(e){ne(t,e);function t(){return e!==null&&e.apply(this,arguments)||this}return t.prototype.flush=function(r){this._active=!0;var n=this._scheduled;this._scheduled=void 0;var o=this.actions,i;r=r||o.shift();do if(i=r.execute(r.state,r.delay))break;while((r=o[0])&&r.id===n&&o.shift());if(this._active=!1,i){for(;(r=o[0])&&r.id===n&&o.shift();)r.unsubscribe();throw i}},t}(Wt);var we=new On(Tn);var R=new U(function(e){return e.complete()});function Dt(e){return e&&A(e.schedule)}function kr(e){return e[e.length-1]}function Qe(e){return A(kr(e))?e.pop():void 0}function Se(e){return Dt(kr(e))?e.pop():void 0}function Vt(e,t){return typeof kr(e)=="number"?e.pop():t}var pt=function(e){return e&&typeof e.length=="number"&&typeof e!="function"};function zt(e){return A(e==null?void 0:e.then)}function Nt(e){return A(e[ft])}function qt(e){return Symbol.asyncIterator&&A(e==null?void 0:e[Symbol.asyncIterator])}function Kt(e){return new TypeError("You provided "+(e!==null&&typeof e=="object"?"an invalid object":"'"+e+"'")+" where a stream was expected. 
You can provide an Observable, Promise, ReadableStream, Array, AsyncIterable, or Iterable.")}function Ki(){return typeof Symbol!="function"||!Symbol.iterator?"@@iterator":Symbol.iterator}var Qt=Ki();function Yt(e){return A(e==null?void 0:e[Qt])}function Gt(e){return ln(this,arguments,function(){var r,n,o,i;return Pt(this,function(a){switch(a.label){case 0:r=e.getReader(),a.label=1;case 1:a.trys.push([1,,9,10]),a.label=2;case 2:return[4,Xe(r.read())];case 3:return n=a.sent(),o=n.value,i=n.done,i?[4,Xe(void 0)]:[3,5];case 4:return[2,a.sent()];case 5:return[4,Xe(o)];case 6:return[4,a.sent()];case 7:return a.sent(),[3,2];case 8:return[3,10];case 9:return r.releaseLock(),[7];case 10:return[2]}})})}function Bt(e){return A(e==null?void 0:e.getReader)}function $(e){if(e instanceof U)return e;if(e!=null){if(Nt(e))return Qi(e);if(pt(e))return Yi(e);if(zt(e))return Gi(e);if(qt(e))return _n(e);if(Yt(e))return Bi(e);if(Bt(e))return Ji(e)}throw Kt(e)}function Qi(e){return new U(function(t){var r=e[ft]();if(A(r.subscribe))return r.subscribe(t);throw new TypeError("Provided object does not correctly implement Symbol.observable")})}function Yi(e){return new U(function(t){for(var r=0;r=2;return function(n){return n.pipe(e?_(function(o,i){return e(o,i,n)}):me,Oe(1),r?He(t):zn(function(){return new Xt}))}}function Nn(){for(var e=[],t=0;t=2,!0))}function fe(e){e===void 0&&(e={});var t=e.connector,r=t===void 0?function(){return new E}:t,n=e.resetOnError,o=n===void 0?!0:n,i=e.resetOnComplete,a=i===void 0?!0:i,s=e.resetOnRefCountZero,f=s===void 0?!0:s;return function(c){var u,p,m,d=0,h=!1,v=!1,B=function(){p==null||p.unsubscribe(),p=void 0},re=function(){B(),u=m=void 0,h=v=!1},z=function(){var T=u;re(),T==null||T.unsubscribe()};return g(function(T,Ke){d++,!v&&!h&&B();var We=m=m!=null?m:r();Ke.add(function(){d--,d===0&&!v&&!h&&(p=jr(z,f))}),We.subscribe(Ke),!u&&d>0&&(u=new et({next:function(Ie){return We.next(Ie)},error:function(Ie){v=!0,B(),p=jr(re,o,Ie),We.error(Ie)},complete:function(){h=!0,B(),p=jr(re,a),We.complete()}}),$(T).subscribe(u))})(c)}}function jr(e,t){for(var r=[],n=2;ne.next(document)),e}function K(e,t=document){return Array.from(t.querySelectorAll(e))}function V(e,t=document){let r=se(e,t);if(typeof r=="undefined")throw new ReferenceError(`Missing element: expected "${e}" to be present`);return r}function se(e,t=document){return t.querySelector(e)||void 0}function _e(){return document.activeElement instanceof HTMLElement&&document.activeElement||void 0}function tr(e){return L(b(document.body,"focusin"),b(document.body,"focusout")).pipe(ke(1),l(()=>{let t=_e();return typeof t!="undefined"?e.contains(t):!1}),N(e===_e()),Y())}function Be(e){return{x:e.offsetLeft,y:e.offsetTop}}function Yn(e){return L(b(window,"load"),b(window,"resize")).pipe(Ce(0,we),l(()=>Be(e)),N(Be(e)))}function rr(e){return{x:e.scrollLeft,y:e.scrollTop}}function dt(e){return L(b(e,"scroll"),b(window,"resize")).pipe(Ce(0,we),l(()=>rr(e)),N(rr(e)))}var Bn=function(){if(typeof Map!="undefined")return Map;function e(t,r){var n=-1;return t.some(function(o,i){return o[0]===r?(n=i,!0):!1}),n}return function(){function t(){this.__entries__=[]}return Object.defineProperty(t.prototype,"size",{get:function(){return this.__entries__.length},enumerable:!0,configurable:!0}),t.prototype.get=function(r){var n=e(this.__entries__,r),o=this.__entries__[n];return o&&o[1]},t.prototype.set=function(r,n){var o=e(this.__entries__,r);~o?this.__entries__[o][1]=n:this.__entries__.push([r,n])},t.prototype.delete=function(r){var 
n=this.__entries__,o=e(n,r);~o&&n.splice(o,1)},t.prototype.has=function(r){return!!~e(this.__entries__,r)},t.prototype.clear=function(){this.__entries__.splice(0)},t.prototype.forEach=function(r,n){n===void 0&&(n=null);for(var o=0,i=this.__entries__;o0},e.prototype.connect_=function(){!zr||this.connected_||(document.addEventListener("transitionend",this.onTransitionEnd_),window.addEventListener("resize",this.refresh),xa?(this.mutationsObserver_=new MutationObserver(this.refresh),this.mutationsObserver_.observe(document,{attributes:!0,childList:!0,characterData:!0,subtree:!0})):(document.addEventListener("DOMSubtreeModified",this.refresh),this.mutationEventsAdded_=!0),this.connected_=!0)},e.prototype.disconnect_=function(){!zr||!this.connected_||(document.removeEventListener("transitionend",this.onTransitionEnd_),window.removeEventListener("resize",this.refresh),this.mutationsObserver_&&this.mutationsObserver_.disconnect(),this.mutationEventsAdded_&&document.removeEventListener("DOMSubtreeModified",this.refresh),this.mutationsObserver_=null,this.mutationEventsAdded_=!1,this.connected_=!1)},e.prototype.onTransitionEnd_=function(t){var r=t.propertyName,n=r===void 0?"":r,o=ya.some(function(i){return!!~n.indexOf(i)});o&&this.refresh()},e.getInstance=function(){return this.instance_||(this.instance_=new e),this.instance_},e.instance_=null,e}(),Jn=function(e,t){for(var r=0,n=Object.keys(t);r0},e}(),Zn=typeof WeakMap!="undefined"?new WeakMap:new Bn,eo=function(){function e(t){if(!(this instanceof e))throw new TypeError("Cannot call a class as a function.");if(!arguments.length)throw new TypeError("1 argument required, but only 0 present.");var r=Ea.getInstance(),n=new Ra(t,r,this);Zn.set(this,n)}return e}();["observe","unobserve","disconnect"].forEach(function(e){eo.prototype[e]=function(){var t;return(t=Zn.get(this))[e].apply(t,arguments)}});var ka=function(){return typeof nr.ResizeObserver!="undefined"?nr.ResizeObserver:eo}(),to=ka;var ro=new E,Ha=I(()=>H(new to(e=>{for(let t of e)ro.next(t)}))).pipe(x(e=>L(Te,H(e)).pipe(C(()=>e.disconnect()))),J(1));function de(e){return{width:e.offsetWidth,height:e.offsetHeight}}function ge(e){return Ha.pipe(S(t=>t.observe(e)),x(t=>ro.pipe(_(({target:r})=>r===e),C(()=>t.unobserve(e)),l(()=>de(e)))),N(de(e)))}function bt(e){return{width:e.scrollWidth,height:e.scrollHeight}}function ar(e){let t=e.parentElement;for(;t&&(e.scrollWidth<=t.scrollWidth&&e.scrollHeight<=t.scrollHeight);)t=(e=t).parentElement;return t?e:void 0}var no=new E,Pa=I(()=>H(new IntersectionObserver(e=>{for(let t of e)no.next(t)},{threshold:0}))).pipe(x(e=>L(Te,H(e)).pipe(C(()=>e.disconnect()))),J(1));function sr(e){return Pa.pipe(S(t=>t.observe(e)),x(t=>no.pipe(_(({target:r})=>r===e),C(()=>t.unobserve(e)),l(({isIntersecting:r})=>r))))}function oo(e,t=16){return dt(e).pipe(l(({y:r})=>{let n=de(e),o=bt(e);return r>=o.height-n.height-t}),Y())}var cr={drawer:V("[data-md-toggle=drawer]"),search:V("[data-md-toggle=search]")};function io(e){return cr[e].checked}function qe(e,t){cr[e].checked!==t&&cr[e].click()}function je(e){let t=cr[e];return b(t,"change").pipe(l(()=>t.checked),N(t.checked))}function $a(e,t){switch(e.constructor){case HTMLInputElement:return e.type==="radio"?/^Arrow/.test(t):!0;case HTMLSelectElement:case HTMLTextAreaElement:return!0;default:return e.isContentEditable}}function Ia(){return L(b(window,"compositionstart").pipe(l(()=>!0)),b(window,"compositionend").pipe(l(()=>!1))).pipe(N(!1))}function ao(){let 
e=b(window,"keydown").pipe(_(t=>!(t.metaKey||t.ctrlKey)),l(t=>({mode:io("search")?"search":"global",type:t.key,claim(){t.preventDefault(),t.stopPropagation()}})),_(({mode:t,type:r})=>{if(t==="global"){let n=_e();if(typeof n!="undefined")return!$a(n,r)}return!0}),fe());return Ia().pipe(x(t=>t?R:e))}function Me(){return new URL(location.href)}function ot(e){location.href=e.href}function so(){return new E}function co(e,t){if(typeof t=="string"||typeof t=="number")e.innerHTML+=t.toString();else if(t instanceof Node)e.appendChild(t);else if(Array.isArray(t))for(let r of t)co(e,r)}function M(e,t,...r){let n=document.createElement(e);if(t)for(let o of Object.keys(t))typeof t[o]!="undefined"&&(typeof t[o]!="boolean"?n.setAttribute(o,t[o]):n.setAttribute(o,""));for(let o of r)co(n,o);return n}function fr(e){if(e>999){let t=+((e-950)%1e3>99);return`${((e+1e-6)/1e3).toFixed(t)}k`}else return e.toString()}function fo(){return location.hash.substring(1)}function uo(e){let t=M("a",{href:e});t.addEventListener("click",r=>r.stopPropagation()),t.click()}function Fa(){return b(window,"hashchange").pipe(l(fo),N(fo()),_(e=>e.length>0),J(1))}function po(){return Fa().pipe(l(e=>se(`[id="${e}"]`)),_(e=>typeof e!="undefined"))}function Nr(e){let t=matchMedia(e);return Zt(r=>t.addListener(()=>r(t.matches))).pipe(N(t.matches))}function lo(){let e=matchMedia("print");return L(b(window,"beforeprint").pipe(l(()=>!0)),b(window,"afterprint").pipe(l(()=>!1))).pipe(N(e.matches))}function qr(e,t){return e.pipe(x(r=>r?t():R))}function ur(e,t={credentials:"same-origin"}){return ve(fetch(`${e}`,t)).pipe(ce(()=>R),x(r=>r.status!==200?Tt(()=>new Error(r.statusText)):H(r)))}function Ue(e,t){return ur(e,t).pipe(x(r=>r.json()),J(1))}function mo(e,t){let r=new DOMParser;return ur(e,t).pipe(x(n=>n.text()),l(n=>r.parseFromString(n,"text/xml")),J(1))}function pr(e){let t=M("script",{src:e});return I(()=>(document.head.appendChild(t),L(b(t,"load"),b(t,"error").pipe(x(()=>Tt(()=>new ReferenceError(`Invalid script: ${e}`))))).pipe(l(()=>{}),C(()=>document.head.removeChild(t)),Oe(1))))}function ho(){return{x:Math.max(0,scrollX),y:Math.max(0,scrollY)}}function bo(){return L(b(window,"scroll",{passive:!0}),b(window,"resize",{passive:!0})).pipe(l(ho),N(ho()))}function vo(){return{width:innerWidth,height:innerHeight}}function go(){return b(window,"resize",{passive:!0}).pipe(l(vo),N(vo()))}function yo(){return Q([bo(),go()]).pipe(l(([e,t])=>({offset:e,size:t})),J(1))}function lr(e,{viewport$:t,header$:r}){let n=t.pipe(X("size")),o=Q([n,r]).pipe(l(()=>Be(e)));return Q([r,t,o]).pipe(l(([{height:i},{offset:a,size:s},{x:f,y:c}])=>({offset:{x:a.x-f,y:a.y-c+i},size:s})))}(()=>{function e(n,o){parent.postMessage(n,o||"*")}function t(...n){return n.reduce((o,i)=>o.then(()=>new Promise(a=>{let s=document.createElement("script");s.src=i,s.onload=a,document.body.appendChild(s)})),Promise.resolve())}var r=class{constructor(n){this.url=n,this.onerror=null,this.onmessage=null,this.onmessageerror=null,this.m=a=>{a.source===this.w&&(a.stopImmediatePropagation(),this.dispatchEvent(new MessageEvent("message",{data:a.data})),this.onmessage&&this.onmessage(a))},this.e=(a,s,f,c,u)=>{if(s===this.url.toString()){let p=new ErrorEvent("error",{message:a,filename:s,lineno:f,colno:c,error:u});this.dispatchEvent(p),this.onerror&&this.onerror(p)}};let o=new EventTarget;this.addEventListener=o.addEventListener.bind(o),this.removeEventListener=o.removeEventListener.bind(o),this.dispatchEvent=o.dispatchEvent.bind(o);let 
i=document.createElement("iframe");i.width=i.height=i.frameBorder="0",document.body.appendChild(this.iframe=i),this.w.document.open(),this.w.document.write(` + + + + + + + + + + + + + + + + + + + + + + +
+ +
+ + + + +
+ + +
+ +
+ + + + + + +
+
+ + + +
+
+
+ + + + +
+
+
+ + + + + + + +
+
+ + + +  + + + + + + +

OIDC Authentication

+

Setting up end-to-end authentication using OIDC is fairly simple and can be done using a Flask wrapper, flaskoidc.

+

flaskoidc leverages Flask’s before_request functionality to authenticate each request before passing it to +the views. It also accepts headers on each request, if available, in order to validate the bearer token of incoming requests.

+
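Conceptually, the wrapper behaves like the sketch below. This is not flaskoidc’s actual code, just an illustration of how a before_request hook can gate every request; the token-validation helper is a placeholder.

```python
# Conceptual sketch of authenticating every request with a before_request
# hook, in the spirit of what flaskoidc does. Not the library's actual code.
from flask import Flask, abort, request

app = Flask(__name__)


def _token_is_valid(token: str) -> bool:
    # Placeholder: flaskoidc validates the bearer token against the OIDC provider.
    return bool(token)


@app.before_request
def authenticate_request():
    auth_header = request.headers.get("Authorization", "")
    if auth_header.startswith("Bearer "):
        if _token_is_valid(auth_header[len("Bearer "):]):
            return None  # token accepted; continue to the view
        abort(401)
    # Without a bearer token, flaskoidc would redirect to the OIDC login flow.
```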

Installation

+

(If you are using flaskoidc<1.0.0, please follow the documentation here.)

+

PREREQUISITE: Please refer to the flaskoidc documentation +for installation and configuration.

+

Note: You need to install and configure flaskoidc for each Amundsen microservice, +i.e., frontendlibrary, metadatalibrary, and searchlibrary, in order to secure each of them.

+

Amundsen Configuration

+

Once you have flaskoidc installed and configured for each microservice, please set the following environment variables:

+
    +
  • +

    amundsenfrontendlibrary (amundsen/frontend): +

        FLASK_APP_MODULE_NAME: flaskoidc
    +    FLASK_APP_CLASS_NAME: FlaskOIDC
    +

    +
  • +
  • +

    amundsenmetadatalibrary (amundsen/metadata): +

        FLASK_APP_MODULE_NAME: flaskoidc
    +    FLASK_APP_CLASS_NAME: FlaskOIDC
    +

    +
  • +
  • +

    amundsensearchlibrary (amundsen/search): +

        FLASK_APP_MODULE_NAME: flaskoidc
    +    FLASK_APP_CLASS_NAME: FlaskOIDC
    +

    +
  • +
+

By default, flaskoidc whitelists the healthcheck URLs so that they are not authenticated. For metadatalibrary and searchlibrary, +we may want to whitelist the healthcheck APIs explicitly using the following environment variable.

+
    FLASK_OIDC_WHITELISTED_ENDPOINTS: 'api.healthcheck'
+
+
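For reference, the FLASK_APP_MODULE_NAME / FLASK_APP_CLASS_NAME pair tells each service which Flask application class to instantiate. The sketch below shows the general mechanism only; it is an assumption for illustration, not Amundsen’s actual entrypoint code.

```python
# Rough sketch of how a service entrypoint can pick its Flask app class
# from environment variables. Illustration only, not Amundsen's entrypoint.
import importlib
import os

module_name = os.environ.get("FLASK_APP_MODULE_NAME", "flask")
class_name = os.environ.get("FLASK_APP_CLASS_NAME", "Flask")

flask_module = importlib.import_module(module_name)   # e.g. flaskoidc
app_class = getattr(flask_module, class_name)         # e.g. FlaskOIDC
app = app_class(__name__)  # FlaskOIDC wires up OIDC auth when constructed
```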

Setting Up Request Headers

+

To communicate securely between the microservices, you need to pass the bearer token from the frontend in each request +to metadatalibrary and searchlibrary. This is done using the REQUEST_HEADERS_METHOD config variable in frontendlibrary.

+
    +
  • Define a function to add the bearer token in each request in your config.py:
  • +
+

version: flaskoidc<1.0.0 +

def get_access_headers(app):
+    try:
+        access_token = app.oidc.get_access_token()
+        return {'Authorization': 'Bearer {}'.format(access_token)}
+    except Exception:
+        return None
+
+version: flaskoidc>=1.0.0 +
import json
+from typing import Dict, Optional
+
+from flask import Flask
+
+def get_access_headers(app: Flask) -> Optional[Dict]:
+    try:
+        # noinspection PyUnresolvedReferences
+        access_token = json.dumps(app.auth_client.token)
+        return {'Authorization': 'Bearer {}'.format(access_token)}
+    except Exception:
+        pass
+

+
    +
  • Set the method as the request header method in your config.py: +
    REQUEST_HEADERS_METHOD = get_access_headers
    +
  • +
+

This function will be called with the current app instance to add the headers to each request made to any endpoint of +metadatalibrary and searchlibrary (see here).

+
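In other words, the frontend resolves the configured method and merges the returned headers into its outgoing requests, roughly like the sketch below. This is an illustrative stand-in, not Amundsen’s actual request wrapper, and the endpoint URL is a placeholder.

```python
# Illustrative sketch of how REQUEST_HEADERS_METHOD is consumed when the
# frontend calls the metadata or search service. Not the real wrapper.
import requests
from flask import current_app as app


def call_metadata_service(url: str) -> requests.Response:
    headers = {}
    request_headers_method = app.config.get("REQUEST_HEADERS_METHOD")
    if request_headers_method is not None:
        headers.update(request_headers_method(app) or {})
    return requests.get(url, headers=headers)
```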

Setting Up Auth User Method

+

In order to get the currently authenticated user (which is used in Amundsen for many operations), we need to set the +AUTH_USER_METHOD config variable in frontendlibrary. +This function should return the email address, user id, and any other required information.

+
    +
  • Define a function to fetch the user information in your config.py:
  • +
+

version: flaskoidc<1.0.0 +

from flask import Flask
+from amundsen_application.models.user import load_user, User
+
+
+def get_auth_user(app: Flask) -> User:
+    from flask import g
+    user_info = load_user(g.oidc_id_token)
+    return user_info
+
+version: flaskoidc>=1.0.0 +
from flask import Flask, session
+from amundsen_application.models.user import load_user, User
+
+
+def get_auth_user(app: Flask) -> User:
+    user_info = load_user(session.get("user"))
+    return user_info
+

+
    +
  • Set the method as the auth user method in your config.py: +
    AUTH_USER_METHOD = get_auth_user
    +
  • +
+

Once done, you’ll have end-to-end authentication in Amundsen without any proxy or code changes.

+

Using Okta with Amundsen on K8s

+

Valid for flaskoidc<1.0.0

+

Assumptions:

+
    +
  • You have access to okta (you can create a developer account for free!)
  • +
  • +

You are using k8s to set up Amundsen. See amundsen-kube-helm

    +
  • +
  • +

    You need to have a stable DNS entry for amundsen-frontend that can be registered in okta.

    +
      +
• For example, in AWS you can set up Route 53. +I will assume for the rest of this tutorial that your stable URI is “http://amundsen-frontend”
    • +
    +
  • +
• You need to register Amundsen in Okta as an app. More info here. +But here are specific instructions for Amundsen:
      +
• At this time, I have only successfully tested the integration after ALL grants were checked.
    • +
    • Set the Login redirect URIs to: http://amundsen-frontend/oidc_callback
    • +
    • No need to set a logout redirect URI
    • +
• Set the Initiate login URI to: http://amundsen-frontend/ + (This is where Okta will take users who click on Amundsen from the Okta landing page.)
    • +
• Copy the Client ID and Client secret, as you will need them later.
    • +
    +
  • +
• At present, there is no OIDC build of the frontend, so you will need to build one yourself and upload it to a registry (for example, ECR) for use by k8s. + You can then specify which image you want to use as a property override for your helm install, like so:
  • +
+
frontEndServiceImage: 123.dkr.ecr.us-west-2.amazonaws.com/edmunds/amundsen-frontend:oidc-test
+
+

Please see further down in this doc for more instructions on how to build the frontend. +4. When you start up helm, you will need to provide some properties. Here are the properties that need to be overridden for OIDC to work:

+
```yaml
+oidcEnabled: true
+createOidcSecret: true
+OIDC_CLIENT_ID: YOUR_CLIENT_ID
+OIDC_CLIENT_SECRET: YOUR_SECRET_ID
+OIDC_ORG_URL: https://amundsen.okta.com
+OIDC_AUTH_SERVER_ID: default
+# You also will need a custom oidc frontend build too
+frontEndServiceImage: 123.dkr.ecr.us-west-2.amazonaws.com/edmunds/amundsen-frontend:oidc-test
+```
+
+ +

Building frontend with OIDC

+
    +
  1. Please look at this guide for instructions on how to build a custom frontend docker image.
  2. +
3. The only difference from the above is that in your Dockerfile you will want to add the following at the end. This makes sure it is ready to go for OIDC. +You can take a look at the public.Dockerfile as a reference.
  4. +
+
RUN pip3 install .[oidc]
ENV FRONTEND_SVC_CONFIG_MODULE_CLASS=amundsen_application.oidc_config.OidcConfig
ENV FLASK_APP_MODULE_NAME=flaskoidc
ENV FLASK_APP_CLASS_NAME=FlaskOIDC
ENV FLASK_OIDC_WHITELISTED_ENDPOINTS=status,healthcheck,health
ENV SQLALCHEMY_DATABASE_URI=sqlite:///sessions.db
+
+

Please also take a look at this blog post for more detail.
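For context, FLASK_APP_MODULE_NAME and FLASK_APP_CLASS_NAME tell the frontend which Flask application class to instantiate, which is how the stock Flask app is swapped for flaskoidc’s FlaskOIDC. The factory below is a hypothetical sketch of that pattern, not the actual amundsen_application startup code:

import importlib
import os

# Hypothetical sketch only: resolve the Flask application class from environment
# variables in the style of FLASK_APP_MODULE_NAME / FLASK_APP_CLASS_NAME.
module_name = os.environ.get("FLASK_APP_MODULE_NAME", "flask")
class_name = os.environ.get("FLASK_APP_CLASS_NAME", "Flask")

app_class = getattr(importlib.import_module(module_name), class_name)
# With FLASK_APP_MODULE_NAME=flaskoidc and FLASK_APP_CLASS_NAME=FlaskOIDC this yields
# a FlaskOIDC app; flaskoidc reads its own OIDC settings from the environment.
app = app_class(__name__)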

+ + + + + + + + + \ No newline at end of file diff --git a/common/CHANGELOG/index.html b/common/CHANGELOG/index.html new file mode 100644 index 0000000000..d89658ffd8 --- /dev/null +++ b/common/CHANGELOG/index.html @@ -0,0 +1,1429 @@ + + + + + + + + + + + + + + + + + + + + + + + + CHANGELOG - Amundsen + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

CHANGELOG


Feature

+
    +
  • Added lineage item and lineage entities (#90) (f1c6011)
  • Add chart into common ES index map (#77) (4a7eea4)
  • Add chart to dashboard model (#73) (241f627)
  • Added badges field (optional) to column in table model (#68) (7bf5a84)
  • Add marshmallow packages to setup.py (#66) (7ff2fe1)
  • Tweaks for gremlin support (#60) (1a2733b)
  • Table model badges field update (#56) (6a393d0)
  • Added new badge model (#55) (09897d9)
  • Add github action for test and pypi publish (#47) (1a466b1)
  • Added resource_reports into Table model (60b1751)
+

Fix

+
    +
  • Standardize requirements and fixes for marshmallow3+ (#98) (d185046)
  • Moved version declaration (#88) (19be687)
  • Fix table index map bug (#86) (f250d6a)
  • Make column names searchable by lowercase (#85) (0ead455)
  • Changed marshmallow-annotation version, temp solution (#81) (ff9d2e2)
  • Enable flake8 and mypy in CI (#75) (32e317c)
  • Fix import (#74) (2d1725b)
  • Add dashboard index map copied from amundsendatabuilder (#65) (551834b)
  • Update elasticsearch mapping (#64) (b43a687)
+ + + + + + + + + \ No newline at end of file diff --git a/common/LICENSE b/common/LICENSE new file mode 100644 index 0000000000..bed437514f --- /dev/null +++ b/common/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2019 Lyft, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/common/MANIFEST.in b/common/MANIFEST.in new file mode 100644 index 0000000000..b36835faa2 --- /dev/null +++ b/common/MANIFEST.in @@ -0,0 +1,3 @@ +include requirements.txt + +global-include requirements-*.txt diff --git a/common/Makefile b/common/Makefile new file mode 100644 index 0000000000..a30dccf877 --- /dev/null +++ b/common/Makefile @@ -0,0 +1,23 @@ +clean: + find . -name \*.pyc -delete + find . -name __pycache__ -delete + rm -rf dist/ + +test_unit: + python -m pytest tests + python3 -bb -m pytest tests + + +lint: + flake8 . + +.PHONY: mypy +mypy: + mypy . + + +test: test_unit lint mypy + +.PHONY: install_deps +install_deps: + pip3 install -e ".[all]" diff --git a/common/NOTICE b/common/NOTICE new file mode 100644 index 0000000000..f7ca551896 --- /dev/null +++ b/common/NOTICE @@ -0,0 +1,4 @@ +amundsencommon +Copyright 2019-2020 Lyft Inc. + +This product includes software developed at Lyft Inc. diff --git a/common/amundsen_common/__init__.py b/common/amundsen_common/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/common/amundsen_common/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/common/amundsen_common/entity/__init__.py b/common/amundsen_common/entity/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/common/amundsen_common/entity/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/common/amundsen_common/entity/resource_type.py b/common/amundsen_common/entity/resource_type.py new file mode 100644 index 0000000000..97dc74bcd2 --- /dev/null +++ b/common/amundsen_common/entity/resource_type.py @@ -0,0 +1,21 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from enum import Enum, auto + + +class ResourceType(Enum): + Table = auto() + Dashboard = auto() + User = auto() + Column = auto() + Type_Metadata = auto() + Feature = auto() + + +def to_resource_type(*, label: str) -> ResourceType: + return ResourceType[label.title()] + + +def to_label(*, resource_type: ResourceType) -> str: + return resource_type.name.lower() diff --git a/common/amundsen_common/log/__init__.py b/common/amundsen_common/log/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/common/amundsen_common/log/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/common/amundsen_common/log/action_log.py b/common/amundsen_common/log/action_log.py new file mode 100644 index 0000000000..b1ac798ec1 --- /dev/null +++ b/common/amundsen_common/log/action_log.py @@ -0,0 +1,99 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import functools + +import json +import logging +import socket +from datetime import datetime, timezone, timedelta + +from typing import Any, Dict, Callable +from flask import current_app as flask_app +from amundsen_common.log import action_log_callback +from amundsen_common.log.action_log_model import ActionLogParams + +LOGGER = logging.getLogger(__name__) +EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc) # use POSIX epoch + +# CONFIG KEY FOR caller_retrieval instance +CALLER_RETRIEVAL_INSTANCE_KEY = 'CALLER_RETRIEVAL_INSTANCE' + + +def action_logging(f: Callable[..., Any]) -> Any: + """ + Decorates function to execute function at the same time triggering action logger callbacks. + It will call action logger callbacks twice, one for pre-execution and the other one for post-execution. + Action logger will be called with ActionLogParams + + :param f: function instance + :return: wrapped function + """ + @functools.wraps(f) + def wrapper( + *args: Any, + **kwargs: Any + ) -> Any: + """ + An wrapper for api functions. It creates ActionLogParams based on the function name, positional arguments, + and keyword arguments. + + :param args: A passthrough positional arguments. + :param kwargs: A passthrough keyword argument + """ + metrics = _build_metrics(f.__name__, *args, **kwargs) + action_log_callback.on_pre_execution(ActionLogParams(**metrics)) + output = None + try: + output = f(*args, **kwargs) + return output + except Exception as e: + metrics['error'] = e + raise + finally: + metrics['end_epoch_ms'] = get_epoch_millisec() + try: + metrics['output'] = json.dumps(output) + except Exception: + metrics['output'] = output + + action_log_callback.on_post_execution(ActionLogParams(**metrics)) + + if LOGGER.isEnabledFor(logging.DEBUG): + LOGGER.debug('action has been logged') + + return wrapper + + +def get_epoch_millisec() -> int: + return (datetime.now(timezone.utc) - EPOCH) // timedelta(milliseconds=1) + + +def _build_metrics( + func_name: str, + *args: Any, + **kwargs: Any +) -> Dict[str, Any]: + """ + Builds metrics dict from function args + :param func_name: + :param args: + :param kwargs: + :return: Dict that matches ActionLogParams variable + """ + + metrics = { + 'command': kwargs.get('command', func_name), + 'start_epoch_ms': get_epoch_millisec(), + 'host_name': socket.gethostname(), + 'pos_args_json': json.dumps(args), + 'keyword_args_json': json.dumps(kwargs), + } # type: Dict[str, Any] + + caller_retriever = flask_app.config.get(CALLER_RETRIEVAL_INSTANCE_KEY, '') + if caller_retriever: + metrics['user'] = caller_retriever.get_caller() + else: + metrics['user'] = 'UNKNOWN' + + return metrics diff --git a/common/amundsen_common/log/action_log_callback.py b/common/amundsen_common/log/action_log_callback.py new file mode 100644 index 0000000000..37a22d85e8 --- /dev/null +++ b/common/amundsen_common/log/action_log_callback.py @@ -0,0 +1,104 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +""" +An Action Logger module. Singleton pattern has been applied into this module +so that registered callbacks can be used all through the same python process. 
+""" + +import logging +import sys +from typing import Callable, List, Any + +from pkg_resources import iter_entry_points + +from amundsen_common.log.action_log_model import ActionLogParams + +LOGGER = logging.getLogger(__name__) + +__pre_exec_callbacks: List[Callable[..., Any]] = [] +__post_exec_callbacks: List[Callable[..., Any]] = [] + + +def register_pre_exec_callback(action_log_callback: Callable[..., Any]) -> None: + """ + Registers more action_logger function callback for pre-execution. This function callback is expected to be called + with keyword args. For more about the arguments that is being passed to the callback, refer to + amundsen_application.log.action_log_model.ActionLogParams + :param action_logger: An action logger callback function + :return: None + """ + LOGGER.debug("Adding {} to pre execution callback".format(action_log_callback)) + __pre_exec_callbacks.append(action_log_callback) + + +def register_post_exec_callback(action_log_callback: Callable[..., Any]) -> None: + """ + Registers more action_logger function callback for post-execution. This function callback is expected to be + called with keyword args. For more about the arguments that is being passed to the callback, + amundsen_application.log.action_log_model.ActionLogParams + :param action_logger: An action logger callback function + :return: None + """ + LOGGER.debug("Adding {} to post execution callback".format(action_log_callback)) + __post_exec_callbacks.append(action_log_callback) + + +def on_pre_execution(action_log_params: ActionLogParams) -> None: + """ + Calls callbacks before execution. + Note that any exception from callback will be logged but won't be propagated. + :param kwargs: + :return: None + """ + LOGGER.debug("Calling callbacks: {}".format(__pre_exec_callbacks)) + for call_back_function in __pre_exec_callbacks: + try: + call_back_function(action_log_params) + except Exception: + logging.exception('Failed on pre-execution callback using {}'.format(call_back_function)) + + +def on_post_execution(action_log_params: ActionLogParams) -> None: + """ + Calls callbacks after execution. As it's being called after execution, it can capture most of fields in + amundsen_application.log.action_log_model.ActionLogParams. Note that any exception from callback will be logged + but won't be propagated. + :param kwargs: + :return: None + """ + LOGGER.debug("Calling callbacks: {}".format(__post_exec_callbacks)) + for call_back_function in __post_exec_callbacks: + try: + call_back_function(action_log_params) + except Exception: + logging.exception('Failed on post-execution callback using {}'.format(call_back_function)) + + +def logging_action_log(action_log_params: ActionLogParams) -> None: + """ + An action logger callback that just logs the ActionLogParams that it receives. + :param **kwargs keyword arguments + :return: None + """ + if LOGGER.isEnabledFor(logging.DEBUG): + LOGGER.debug('logging_action_log: {}'.format(action_log_params)) + + +def register_action_logs() -> None: + """ + Retrieve declared action log callbacks from entry point where there are two groups that can be registered: + 1. "action_log.post_exec.plugin": callback for pre-execution + 2. 
"action_log.pre_exec.plugin": callback for post-execution + :return: None + """ + for entry_point in iter_entry_points(group='action_log.post_exec.plugin', name=None): + print('Registering post_exec action_log entry_point: {}'.format(entry_point), file=sys.stderr) + register_post_exec_callback(entry_point.load()) + + for entry_point in iter_entry_points(group='action_log.pre_exec.plugin', name=None): + print('Registering pre_exec action_log entry_point: {}'.format(entry_point), file=sys.stderr) + register_pre_exec_callback(entry_point.load()) + + +register_action_logs() diff --git a/common/amundsen_common/log/action_log_model.py b/common/amundsen_common/log/action_log_model.py new file mode 100644 index 0000000000..ffdf339243 --- /dev/null +++ b/common/amundsen_common/log/action_log_model.py @@ -0,0 +1,46 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional + + +class ActionLogParams(object): + """ + Holds parameters for Action log + """ + def __init__( + self, *, + command: str, + start_epoch_ms: int, + end_epoch_ms: Optional[int] = 0, + user: str, + host_name: str, + pos_args_json: str, + keyword_args_json: str, + output: Any = None, + error: Optional[Exception] = None + ) -> None: + self.command = command + self.start_epoch_ms = start_epoch_ms + self.end_epoch_ms = end_epoch_ms + self.user = user + self.host_name = host_name + self.pos_args_json = pos_args_json + self.keyword_args_json = keyword_args_json + self.output = output + self.error = error + + def __repr__(self) -> str: + return 'ActionLogParams(command={!r}, start_epoch_ms={!r}, end_epoch_ms={!r}, user={!r}, ' \ + 'host_name={!r}, pos_args_json={!r}, keyword_args_json={!r}, output={!r}, error={!r})'\ + .format( + self.command, + self.start_epoch_ms, + self.end_epoch_ms, + self.user, + self.host_name, + self.pos_args_json, + self.keyword_args_json, + self.output, + self.error, + ) diff --git a/common/amundsen_common/log/auth_caller_retrieval.py b/common/amundsen_common/log/auth_caller_retrieval.py new file mode 100644 index 0000000000..4d0787a7fc --- /dev/null +++ b/common/amundsen_common/log/auth_caller_retrieval.py @@ -0,0 +1,15 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import getpass + +from flask import current_app as flask_app + +from amundsen_common.log.caller_retrieval import BaseCallerRetriever + + +class AuthCallerRetrieval(BaseCallerRetriever): + def get_caller(self) -> str: + if flask_app.config.get('AUTH_USER_METHOD', None): + return flask_app.config['AUTH_USER_METHOD'](flask_app).email + return getpass.getuser() diff --git a/common/amundsen_common/log/caller_retrieval.py b/common/amundsen_common/log/caller_retrieval.py new file mode 100644 index 0000000000..a150169796 --- /dev/null +++ b/common/amundsen_common/log/caller_retrieval.py @@ -0,0 +1,11 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABCMeta, abstractmethod + + +class BaseCallerRetriever(metaclass=ABCMeta): + + @abstractmethod + def get_caller(self) -> str: + pass diff --git a/common/amundsen_common/log/http_header_caller_retrieval.py b/common/amundsen_common/log/http_header_caller_retrieval.py new file mode 100644 index 0000000000..b1546fe33a --- /dev/null +++ b/common/amundsen_common/log/http_header_caller_retrieval.py @@ -0,0 +1,15 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from flask import current_app as flask_app +from flask import request + +from amundsen_common.log.caller_retrieval import BaseCallerRetriever + +CALLER_HEADER_KEY = 'CALLER_HEADER_KEY' + + +class HttpHeaderCallerRetrieval(BaseCallerRetriever): + def get_caller(self) -> str: + header_key = flask_app.config.get(CALLER_HEADER_KEY, 'user-agent') + return request.headers.get(header_key) or 'UNKNOWN' diff --git a/common/amundsen_common/models/__init__.py b/common/amundsen_common/models/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/common/amundsen_common/models/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/common/amundsen_common/models/api/__init__.py b/common/amundsen_common/models/api/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/common/amundsen_common/models/api/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/common/amundsen_common/models/api/health_check.py b/common/amundsen_common/models/api/health_check.py new file mode 100644 index 0000000000..3459699880 --- /dev/null +++ b/common/amundsen_common/models/api/health_check.py @@ -0,0 +1,39 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import attr +from http import HTTPStatus +from typing import Any, Dict + +from marshmallow3_annotations.ext.attrs import AttrsSchema + +OK = 'ok' +FAIL = 'fail' +_ELIGIBLE_HEALTH_CHECKS = [OK, FAIL] +_HEALTH_CHECK_HTTP_STATUS_MAP = { + OK: HTTPStatus.OK, + FAIL: HTTPStatus.SERVICE_UNAVAILABLE +} + + +@attr.s(auto_attribs=True, kw_only=True) +class HealthCheck: + status: str = attr.ib() + checks: Dict[str, Any] = dict() + + @status.validator + def vaildate_status(self, attribute: str, value: Any) -> None: + if value not in _ELIGIBLE_HEALTH_CHECKS: + raise ValueError(f"status must be one of {_ELIGIBLE_HEALTH_CHECKS}") + + def get_http_status(self) -> int: + return _HEALTH_CHECK_HTTP_STATUS_MAP[self.status] + + def dict(self) -> Dict[str, Any]: + return attr.asdict(self) # type: ignore + + +class HealthCheckSchema(AttrsSchema): + class Meta: + target = HealthCheck + register_as_scheme = True diff --git a/common/amundsen_common/models/badge.py b/common/amundsen_common/models/badge.py new file mode 100644 index 0000000000..2b941e6386 --- /dev/null +++ b/common/amundsen_common/models/badge.py @@ -0,0 +1,18 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import attr + +from marshmallow3_annotations.ext.attrs import AttrsSchema + + +@attr.s(auto_attribs=True, kw_only=True) +class Badge: + badge_name: str = attr.ib() + category: str = attr.ib() + + +class BadgeSchema(AttrsSchema): + class Meta: + target = Badge + register_as_scheme = True diff --git a/common/amundsen_common/models/dashboard.py b/common/amundsen_common/models/dashboard.py new file mode 100644 index 0000000000..92ac85536b --- /dev/null +++ b/common/amundsen_common/models/dashboard.py @@ -0,0 +1,27 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional + +import attr +from marshmallow3_annotations.ext.attrs import AttrsSchema + + +@attr.s(auto_attribs=True, kw_only=True) +class DashboardSummary: + uri: str = attr.ib() + cluster: str = attr.ib() + group_name: str = attr.ib() + group_url: str = attr.ib() + product: str = attr.ib() + name: str = attr.ib() + url: str = attr.ib() + description: Optional[str] = None + last_successful_run_timestamp: Optional[float] = None + chart_names: Optional[List[str]] = [] + + +class DashboardSummarySchema(AttrsSchema): + class Meta: + target = DashboardSummary + register_as_scheme = True diff --git a/common/amundsen_common/models/feature.py b/common/amundsen_common/models/feature.py new file mode 100644 index 0000000000..f54de742b8 --- /dev/null +++ b/common/amundsen_common/models/feature.py @@ -0,0 +1,95 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Dict, Any + +import attr + +from amundsen_common.models.user import User +from amundsen_common.models.badge import Badge +from amundsen_common.models.tag import Tag +from amundsen_common.models.table import ProgrammaticDescription +from marshmallow3_annotations.ext.attrs import AttrsSchema + + +@attr.s(auto_attribs=True, kw_only=True) +class ColumnItem: + column_name: str + column_type: str + + +class ColumnItemSchema(AttrsSchema): + class Meta: + target = ColumnItem + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class DataSample: + # Modeled after preview data model in FE + columns: List[ColumnItem] + data: List[Dict[str, Any]] + error_text: str + + +class DataSampleSchema(AttrsSchema): + class Meta: + target = DataSample + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class FeatureWatermark: + key: Optional[str] + watermark_type: Optional[str] + time: str + + +class FeatureWatermarkSchema(AttrsSchema): + class Meta: + target = FeatureWatermark + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class Feature: + key: Optional[str] = attr.ib(default=None) + name: str + version: str # ex: "1.2.0" + status: Optional[str] + feature_group: str + entity: Optional[str] + data_type: Optional[str] + availability: List[str] + description: Optional[str] = attr.ib(default=None) + owners: List[User] + badges: List[Badge] + tags: List[Tag] + programmatic_descriptions: List[ProgrammaticDescription] + watermarks: List[FeatureWatermark] + last_updated_timestamp: Optional[int] + created_timestamp: Optional[int] + + +class FeatureSchema(AttrsSchema): + class Meta: + target = Feature + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class FeatureSummary: + key: str # ex: test_feature_group_name/test_feature_name/1.2.0 + name: str + version: str + availability: List[str] + entity: List[str] + description: Optional[str] = attr.ib(default=None) + badges: List[Badge] + last_updated_timestamp: Optional[int] + + +class FeatureSummarySchema(AttrsSchema): + class Meta: + target = FeatureSummary + register_as_scheme = True diff --git a/common/amundsen_common/models/generation_code.py b/common/amundsen_common/models/generation_code.py new file mode 100644 index 0000000000..d22ecdb300 --- /dev/null +++ b/common/amundsen_common/models/generation_code.py @@ -0,0 +1,21 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import attr + +from marshmallow3_annotations.ext.attrs import AttrsSchema + + +@attr.s(auto_attribs=True, kw_only=True) +class GenerationCode: + key: Optional[str] + text: str + source: Optional[str] + + +class GenerationCodeSchema(AttrsSchema): + class Meta: + target = GenerationCode + register_as_scheme = True diff --git a/common/amundsen_common/models/index_map.py b/common/amundsen_common/models/index_map.py new file mode 100644 index 0000000000..4deeba243c --- /dev/null +++ b/common/amundsen_common/models/index_map.py @@ -0,0 +1,331 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import textwrap + +# Specifying default mapping for elasticsearch index +# Documentation: https://www.elastic.co/guide/en/elasticsearch/reference/current/mapping.html +# Setting type to "text" for all fields that would be used in search +# Using Simple Analyzer to convert all text into search terms +# https://www.elastic.co/guide/en/elasticsearch/reference/current/analysis-simple-analyzer.html +# Standard Analyzer is used for all text fields that don't explicitly specify an analyzer +# https://www.elastic.co/guide/en/elasticsearch/reference/current/analysis-standard-analyzer.html +TABLE_INDEX_MAP = textwrap.dedent( + """ + { + "settings": { + "analysis": { + "normalizer": { + "column_names_normalizer": { + "type": "custom", + "filter": ["lowercase"] + } + } + } + }, + "mappings": { + "properties": { + "name": { + "type": "text", + "analyzer": "simple", + "fields": { + "raw": { + "type": "keyword" + } + } + }, + "schema": { + "type": "text", + "analyzer": "simple", + "fields": { + "raw": { + "type": "keyword" + } + } + }, + "display_name": { + "type": "keyword" + }, + "last_updated_timestamp": { + "type": "date", + "format": "epoch_second" + }, + "description": { + "type": "text", + "analyzer": "simple" + }, + "column_names": { + "type": "text", + "analyzer": "simple", + "fields": { + "raw": { + "type": "keyword", + "normalizer": "column_names_normalizer" + } + } + }, + "column_descriptions": { + "type": "text", + "analyzer": "simple" + }, + "tags": { + "type": "keyword" + }, + "badges": { + "type": "keyword" + }, + "cluster": { + "type": "text", + "analyzer": "simple", + "fields": { + "raw": { + "type": "keyword" + } + } + }, + "database": { + "type": "text", + "analyzer": "simple", + "fields": { + "raw": { + "type": "keyword" + } + } + }, + "key": { + "type": "keyword" + }, + "total_usage": { + "type": "long" + }, + "unique_usage": { + "type": "long" + }, + "programmatic_descriptions": { + "type": "text", + "analyzer": "simple" + } + } + } + } + """ +) + +DASHBOARD_ELASTICSEARCH_INDEX_MAPPING = textwrap.dedent( + """ + { + "settings": { + "analysis": { + "normalizer": { + "lowercase_normalizer": { + "type": "custom", + "char_filter": [], + "filter": ["lowercase", "asciifolding"] + } + } + } + }, + "mappings": { + "properties": { + "group_name": { + "type": "text", + "analyzer": "simple", + "fields": { + "raw": { + "type": "keyword", + "normalizer": "lowercase_normalizer" + } + } + }, + "name": { + "type": "text", + "analyzer": "simple", + "fields": { + "raw": { + "type": "keyword", + "normalizer": "lowercase_normalizer" + } + } + }, + "description": { + "type": "text", + "analyzer": "simple", + "fields": { + "raw": { + "type": "keyword" + } + } + }, + "group_description": { + "type": "text", + "analyzer": "simple", + "fields": { + "raw": { + "type": "keyword" + } + } + }, + "query_names": { 
+ "type": "text", + "analyzer": "simple", + "fields": { + "raw": { + "type": "keyword" + } + } + }, + "chart_names": { + "type": "text", + "analyzer": "simple", + "fields": { + "raw": { + "type": "keyword" + } + } + }, + "tags": { + "type": "keyword" + }, + "badges": { + "type": "keyword" + } + } + } + } + """ +) + +USER_INDEX_MAP = textwrap.dedent( + """ + { + "mappings": { + "properties": { + "email": { + "type": "text", + "analyzer": "simple", + "fields": { + "raw": { + "type": "keyword" + } + } + }, + "first_name": { + "type": "text", + "analyzer": "simple", + "fields": { + "raw": { + "type": "keyword" + } + } + }, + "last_name": { + "type": "text", + "analyzer": "simple", + "fields": { + "raw": { + "type": "keyword" + } + } + }, + "full_name": { + "type": "text", + "analyzer": "simple", + "fields": { + "raw": { + "type": "keyword" + } + } + }, + "total_read": { + "type": "long" + }, + "total_own": { + "type": "long" + }, + "total_follow": { + "type": "long" + } + } + } + } + """ +) + +FEATURE_INDEX_MAP = textwrap.dedent( + """ + { + "settings": { + "analysis": { + "normalizer": { + "lowercase_normalizer": { + "type": "custom", + "filter": ["lowercase"] + } + } + } + }, + "mappings": { + "properties": { + "feature_group": { + "type": "text", + "analyzer": "simple", + "fields": { + "raw": { + "type": "keyword", + "normalizer": "lowercase_normalizer" + } + } + }, + "feature_name": { + "type": "text", + "analyzer": "simple", + "fields": { + "raw": { + "type": "keyword", + "normalizer": "lowercase_normalizer" + } + } + }, + "version": { + "type": "keyword", + "normalizer": "lowercase_normalizer" + }, + "key": { + "type": "keyword" + }, + "total_usage": { + "type": "long" + }, + "status": { + "type": "keyword" + }, + "entity": { + "type": "keyword" + }, + "description": { + "type": "text" + }, + "availability": { + "type": "text", + "analyzer": "simple", + "fields": { + "raw": { + "type": "keyword" + } + } + }, + "badges": { + "type": "keyword" + }, + "tags": { + "type": "keyword" + }, + "last_updated_timestamp": { + "type": "date", + "format": "epoch_second" + } + } + } +} +""" +) diff --git a/common/amundsen_common/models/lineage.py b/common/amundsen_common/models/lineage.py new file mode 100644 index 0000000000..cdf3c69154 --- /dev/null +++ b/common/amundsen_common/models/lineage.py @@ -0,0 +1,44 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, List + +from amundsen_common.models.badge import Badge + +import attr +from marshmallow3_annotations.ext.attrs import AttrsSchema + + +@attr.s(auto_attribs=True, kw_only=True) +class LineageItem: + key: str # down/upstream table/col/task key + level: int # upstream/downstream distance from current resource + source: str # database this resource is from + badges: Optional[List[Badge]] = None + usage: Optional[int] = None # statistic to sort lineage items by + parent: Optional[str] = None # key of the parent entity, used to create the relationships in graph + link: Optional[str] = None # internal link to redirect to different than the resource details page + in_amundsen: Optional[bool] = None # it is possible to have lineage that doesn't exist in Amundsen in that moment + + +class LineageItemSchema(AttrsSchema): + class Meta: + target = LineageItem + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class Lineage: + key: str # current table/col/task key + direction: str # upstream/downstream/both + depth: int # how many levels up/down 0 == all + upstream_entities: List[LineageItem] # list of upstream entities + downstream_entities: List[LineageItem] # list of downstream entities + upstream_count: Optional[int] = None # number of total upstream entities + downstream_count: Optional[int] = None # number of total downstream entities + + +class LineageSchema(AttrsSchema): + class Meta: + target = Lineage + register_as_scheme = True diff --git a/common/amundsen_common/models/popular_table.py b/common/amundsen_common/models/popular_table.py new file mode 100644 index 0000000000..cccfd6dfb2 --- /dev/null +++ b/common/amundsen_common/models/popular_table.py @@ -0,0 +1,28 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import attr +from marshmallow3_annotations.ext.attrs import AttrsSchema + + +@attr.s(auto_attribs=True, kw_only=True) +class PopularTable: + """ + DEPRECATED. Use TableSummary + """ + database: str = attr.ib() + cluster: str = attr.ib() + schema: str = attr.ib() + name: str = attr.ib() + description: Optional[str] = None + + +class PopularTableSchema(AttrsSchema): + """ + DEPRECATED. Use TableSummary + """ + class Meta: + target = PopularTable + register_as_scheme = True diff --git a/common/amundsen_common/models/search.py b/common/amundsen_common/models/search.py new file mode 100644 index 0000000000..b298c2dbb1 --- /dev/null +++ b/common/amundsen_common/models/search.py @@ -0,0 +1,79 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, List, Optional, Dict + +import attr + +from marshmallow3_annotations.ext.attrs import AttrsSchema + + +@attr.s(auto_attribs=True, kw_only=True) +class HighlightOptions: + enable_highlight: bool = False + + +class HighlightOptionsSchema(AttrsSchema): + class Meta: + target = HighlightOptions + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class Filter: + name: str + values: List[str] + operation: str + + +class FilterSchema(AttrsSchema): + class Meta: + target = Filter + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class SearchRequest: + query_term: str + resource_types: List[str] = [] + page_index: Optional[int] = 0 + results_per_page: Optional[int] = 10 + filters: List[Filter] = [] + # highlight options are defined per resource + highlight_options: Optional[Dict[str, HighlightOptions]] = {} + + +class SearchRequestSchema(AttrsSchema): + class Meta: + target = SearchRequest + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class SearchResponse: + msg: str + page_index: int + results_per_page: int + results: Dict[str, Any] + status_code: int + + +class SearchResponseSchema(AttrsSchema): + class Meta: + target = SearchResponse + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class UpdateDocumentRequest: + resource_key: str + resource_type: str + field: str + value: Optional[str] + operation: str # can be add or overwrite + + +class UpdateDocumentRequestSchema(AttrsSchema): + class Meta: + target = UpdateDocumentRequest + register_as_scheme = True diff --git a/common/amundsen_common/models/table.py b/common/amundsen_common/models/table.py new file mode 100644 index 0000000000..7ded9e1213 --- /dev/null +++ b/common/amundsen_common/models/table.py @@ -0,0 +1,218 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional + +import attr + +from amundsen_common.models.user import User +from amundsen_common.models.badge import Badge +from amundsen_common.models.tag import Tag +from marshmallow3_annotations.ext.attrs import AttrsSchema + + +@attr.s(auto_attribs=True, kw_only=True) +class Reader: + user: User + read_count: int + + +class ReaderSchema(AttrsSchema): + class Meta: + target = Reader + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class Watermark: + watermark_type: Optional[str] = None + partition_key: Optional[str] = None + partition_value: Optional[str] = None + create_time: Optional[str] = None + + +class WatermarkSchema(AttrsSchema): + class Meta: + target = Watermark + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class Stat: + stat_type: str + stat_val: Optional[str] = None + start_epoch: Optional[int] = None + end_epoch: Optional[int] = None + is_metric: Optional[bool] = None + + +class StatSchema(AttrsSchema): + class Meta: + target = Stat + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class TypeMetadata: + kind: str + name: str + key: str + description: Optional[str] = None + data_type: str + sort_order: int + badges: List[Badge] = [] + children: List['TypeMetadata'] = [] + + +class TypeMetadataSchema(AttrsSchema): + class Meta: + target = TypeMetadata + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class Column: + name: str + key: Optional[str] = None + description: Optional[str] = None + col_type: str + sort_order: int + stats: List[Stat] = [] + badges: Optional[List[Badge]] = [] + type_metadata: Optional[TypeMetadata] = None # Used to support complex column types + + +class ColumnSchema(AttrsSchema): + class Meta: + target = Column + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class Application: + application_url: Optional[str] = None + description: Optional[str] = None + id: str + name: Optional[str] = None + kind: Optional[str] = None + + +class ApplicationSchema(AttrsSchema): + class Meta: + target = Application + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class Source: + source_type: str + source: str + + +class SourceSchema(AttrsSchema): + class Meta: + target = Source + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class ResourceReport: + name: str + url: str + + +class ResourceReportSchema(AttrsSchema): + class Meta: + target = ResourceReport + register_as_scheme = True + + +# this is a temporary hack to satisfy mypy. 
Once https://github.com/python/mypy/issues/6136 is resolved, use +# `attr.converters.default_if_none(default=False)` +def default_if_none(arg: Optional[bool]) -> bool: + return arg or False + + +@attr.s(auto_attribs=True, kw_only=True) +class ProgrammaticDescription: + source: str + text: str + + +class ProgrammaticDescriptionSchema(AttrsSchema): + class Meta: + target = ProgrammaticDescription + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class TableSummary: + database: str = attr.ib() + cluster: str = attr.ib() + schema: str = attr.ib() + name: str = attr.ib() + description: Optional[str] = attr.ib(default=None) + schema_description: Optional[str] = attr.ib(default=None) + + +class TableSummarySchema(AttrsSchema): + class Meta: + target = TableSummary + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class SqlJoin: + column: str + joined_on_table: TableSummary + joined_on_column: str + join_type: str + join_sql: str + + +class SqlJoinSchema(AttrsSchema): + class Meta: + target = SqlJoin + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class SqlWhere: + where_clause: str + + +class SqlWhereSchema(AttrsSchema): + class Meta: + target = SqlWhere + register_as_scheme = True + + +@attr.s(auto_attribs=True, kw_only=True) +class Table: + database: str + cluster: str + schema: str + name: str + key: Optional[str] = None + tags: List[Tag] = [] + badges: List[Badge] = [] + table_readers: List[Reader] = [] + description: Optional[str] = None + columns: List[Column] + owners: List[User] = [] + watermarks: List[Watermark] = [] + table_writer: Optional[Application] = None + table_apps: Optional[List[Application]] = None + resource_reports: Optional[List[ResourceReport]] = None + last_updated_timestamp: Optional[int] = None + source: Optional[Source] = None + is_view: Optional[bool] = attr.ib(default=None, converter=default_if_none) + programmatic_descriptions: List[ProgrammaticDescription] = [] + common_joins: Optional[List[SqlJoin]] = None + common_filters: Optional[List[SqlWhere]] = None + + +class TableSchema(AttrsSchema): + class Meta: + target = Table + register_as_scheme = True diff --git a/common/amundsen_common/models/tag.py b/common/amundsen_common/models/tag.py new file mode 100644 index 0000000000..ccda2250d3 --- /dev/null +++ b/common/amundsen_common/models/tag.py @@ -0,0 +1,18 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import attr + +from marshmallow3_annotations.ext.attrs import AttrsSchema + + +@attr.s(auto_attribs=True, kw_only=True) +class Tag: + tag_type: str + tag_name: str + + +class TagSchema(AttrsSchema): + class Meta: + target = Tag + register_as_scheme = True diff --git a/common/amundsen_common/models/user.py b/common/amundsen_common/models/user.py new file mode 100644 index 0000000000..5a874ff7db --- /dev/null +++ b/common/amundsen_common/models/user.py @@ -0,0 +1,90 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional, Dict + +import attr +from marshmallow import EXCLUDE, ValidationError, validates_schema, pre_load +from marshmallow3_annotations.ext.attrs import AttrsSchema + +""" +TODO: Explore all internationalization use cases and +redesign how User handles names + +TODO - Delete pre processing of the Data +Once all of the upstream services provide a complete User object we will no +longer need to supplement the User objects as done in `preprocess_data` +""" + + +@attr.s(auto_attribs=True, kw_only=True) +class User: + # ToDo (Verdan): Make user_id a required field. + # In case if there is only email, id could be email. + # All the transactions and communication will be handled by ID + user_id: Optional[str] = None + email: Optional[str] = None + first_name: Optional[str] = None + last_name: Optional[str] = None + full_name: Optional[str] = None + display_name: Optional[str] = None + is_active: bool = True + github_username: Optional[str] = None + team_name: Optional[str] = None + slack_id: Optional[str] = None + employee_type: Optional[str] = None + manager_fullname: Optional[str] = None + manager_email: Optional[str] = None + manager_id: Optional[str] = None + role_name: Optional[str] = None + profile_url: Optional[str] = None + other_key_values: Optional[Dict[str, str]] = attr.ib(factory=dict) # type: ignore + # TODO: Add frequent_used, bookmarked, & owned resources + + +class UserSchema(AttrsSchema): + class Meta: + target = User + register_as_scheme = True + unknown = EXCLUDE + + # noinspection PyMethodMayBeStatic + def _str_no_value(self, s: Optional[str]) -> bool: + # Returns True if the given string is None or empty + if not s: + return True + if len(s.strip()) == 0: + return True + return False + + @pre_load + def preprocess_data(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + if self._str_no_value(data.get('user_id')): + data['user_id'] = data.get('email') + + if self._str_no_value(data.get('profile_url')): + data['profile_url'] = '' + if data.get('GET_PROFILE_URL'): + data['profile_url'] = data.get('GET_PROFILE_URL')(data['user_id']) # type: ignore + + first_name = data.get('first_name') + last_name = data.get('last_name') + + if self._str_no_value(data.get('full_name')) and first_name and last_name: + data['full_name'] = f"{first_name} {last_name}" + + if self._str_no_value(data.get('display_name')): + if self._str_no_value(data.get('full_name')): + data['display_name'] = data.get('email') + else: + data['display_name'] = data.get('full_name') + + return data + + @validates_schema + def validate_user(self, data: Dict[str, Any], **kwargs: Any) -> None: + if self._str_no_value(data.get('display_name')): + raise ValidationError('"display_name", "full_name", or "email" must be provided') + + if self._str_no_value(data.get('user_id')): + raise ValidationError('"user_id" or "email" must be provided') diff --git a/common/amundsen_common/py.typed b/common/amundsen_common/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/common/amundsen_common/tests/__init__.py b/common/amundsen_common/tests/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/common/amundsen_common/tests/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 diff --git a/common/amundsen_common/tests/fixtures.py b/common/amundsen_common/tests/fixtures.py new file mode 100644 index 0000000000..4fbd040871 --- /dev/null +++ b/common/amundsen_common/tests/fixtures.py @@ -0,0 +1,227 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import string +from typing import Any, List, Optional + +from amundsen_common.models.table import (Application, Column, + ProgrammaticDescription, Stat, Table, + Tag) +from amundsen_common.models.user import User + + +class Fixtures: + """ + These fixtures are useful for creating test objects. For an example usage, check out tests/tests/test_fixtures.py + """ + counter = 1000 + + @staticmethod + def next_int() -> int: + i = Fixtures.counter + Fixtures.counter += 1 + return i + + @staticmethod + def next_string(*, prefix: str = '', length: int = 10) -> str: + astr: str = prefix + \ + ''.join(Fixtures.next_item(items=list(string.ascii_lowercase)) for _ in range(length)) + \ + ('%06d' % Fixtures.next_int()) + return astr + + @staticmethod + def next_range() -> range: + return range(0, Fixtures.next_int() % 5) + + @staticmethod + def next_item(*, items: List[Any]) -> Any: + return items[Fixtures.next_int() % len(items)] + + @staticmethod + def next_database() -> str: + return Fixtures.next_item(items=list(["database1", "database2"])) + + @staticmethod + def next_application(*, application_id: Optional[str] = None) -> Application: + if not application_id: + application_id = Fixtures.next_string(prefix='ap', length=8) + application = Application(application_url=f'https://{application_id}.example.com', + description=f'{application_id} description', + name=application_id.capitalize(), + id=application_id) + return application + + @staticmethod + def next_tag(*, tag_name: Optional[str] = None) -> Tag: + if not tag_name: + tag_name = Fixtures.next_string(prefix='ta', length=8) + return Tag(tag_name=tag_name, tag_type='default') + + @staticmethod + def next_tags() -> List[Tag]: + return sorted([Fixtures.next_tag() for _ in Fixtures.next_range()]) + + @staticmethod + def next_description_source() -> str: + return Fixtures.next_string(prefix='de', length=8) + + @staticmethod + def next_description(*, text: Optional[str] = None, source: Optional[str] = None) -> ProgrammaticDescription: + if not text: + text = Fixtures.next_string(length=20) + if not source: + source = Fixtures.next_description_source() + return ProgrammaticDescription(text=text, source=source) + + @staticmethod + def next_col_type() -> str: + return Fixtures.next_item(items=['varchar', 'int', 'blob', 'timestamp', 'datetime']) + + @staticmethod + def next_column(*, + table_key: str, + sort_order: int, + name: Optional[str] = None) -> Column: + if not name: + name = Fixtures.next_string(prefix='co', length=8) + + return Column(name=name, + description=f'{name} description', + col_type=Fixtures.next_col_type(), + key=f'{table_key}/{name}', + sort_order=sort_order, + stats=[Stat(stat_type='num_rows', + stat_val=f'{Fixtures.next_int() * 100}', + start_epoch=None, + end_epoch=None)]) + + @staticmethod + def next_columns(*, + table_key: str, + randomize_pii: bool = False, + randomize_data_subject: bool = False) -> List[Column]: + return [Fixtures.next_column(table_key=table_key, + sort_order=i + ) for i in Fixtures.next_range()] + + @staticmethod + def next_descriptions() -> List[ProgrammaticDescription]: + return sorted([Fixtures.next_description() for _ in Fixtures.next_range()]) + + 
@staticmethod + def next_table(table: Optional[str] = None, + cluster: Optional[str] = None, + schema: Optional[str] = None, + database: Optional[str] = None, + tags: Optional[List[Tag]] = None, + application: Optional[Application] = None) -> Table: + """ + Returns a table for testing in the test_database + """ + if not database: + database = Fixtures.next_database() + + if not table: + table = Fixtures.next_string(prefix='tb', length=8) + + if not cluster: + cluster = Fixtures.next_string(prefix='cl', length=8) + + if not schema: + schema = Fixtures.next_string(prefix='sc', length=8) + + if not tags: + tags = Fixtures.next_tags() + + table_key: str = f'{database}://{cluster}.{schema}/{table}' + # TODO: add owners, watermarks, last_udpated_timestamp, source + return Table(database=database, + cluster=cluster, + schema=schema, + name=table, + key=table_key, + tags=tags, + table_writer=application, + table_readers=[], + description=f'{table} description', + programmatic_descriptions=Fixtures.next_descriptions(), + columns=Fixtures.next_columns(table_key=table_key), + is_view=False + ) + + @staticmethod + def next_user(*, user_id: Optional[str] = None, is_active: bool = True) -> User: + last_name = ''.join(Fixtures.next_item(items=list(string.ascii_lowercase)) for _ in range(6)).capitalize() + first_name = Fixtures.next_item(items=['alice', 'bob', 'carol', 'dan']).capitalize() + if not user_id: + user_id = Fixtures.next_string(prefix='us', length=8) + return User(user_id=user_id, + email=f'{user_id}@example.com', + is_active=is_active, + first_name=first_name, + last_name=last_name, + full_name=f'{first_name} {last_name}') + + +def next_application(**kwargs: Any) -> Application: + return Fixtures.next_application(**kwargs) + + +def next_int() -> int: + return Fixtures.next_int() + + +def next_string(**kwargs: Any) -> str: + return Fixtures.next_string(**kwargs) + + +def next_range() -> range: + return Fixtures.next_range() + + +def next_item(**kwargs: Any) -> Any: + return Fixtures.next_item(**kwargs) + + +def next_database() -> str: + return Fixtures.next_database() + + +def next_tag(**kwargs: Any) -> Tag: + return Fixtures.next_tag(**kwargs) + + +def next_tags() -> List[Tag]: + return Fixtures.next_tags() + + +def next_description_source() -> str: + return Fixtures.next_description_source() + + +def next_description(**kwargs: Any) -> ProgrammaticDescription: + return Fixtures.next_description(**kwargs) + + +def next_col_type() -> str: + return Fixtures.next_col_type() + + +def next_column(**kwargs: Any) -> Column: + return Fixtures.next_column(**kwargs) + + +def next_columns(**kwargs: Any) -> List[Column]: + return Fixtures.next_columns(**kwargs) + + +def next_descriptions() -> List[ProgrammaticDescription]: + return Fixtures.next_descriptions() + + +def next_table(**kwargs: Any) -> Table: + return Fixtures.next_table(**kwargs) + + +def next_user(**kwargs: Any) -> User: + return Fixtures.next_user(**kwargs) diff --git a/common/amundsen_common/utils/__init__.py b/common/amundsen_common/utils/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/common/amundsen_common/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/common/amundsen_common/utils/atlas.py b/common/amundsen_common/utils/atlas.py new file mode 100644 index 0000000000..14d5b474ea --- /dev/null +++ b/common/amundsen_common/utils/atlas.py @@ -0,0 +1,295 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 +import abc +import re +from typing import Any, Dict, Optional, Set + + +class AtlasStatus: + ACTIVE = "ACTIVE" + DELETED = "DELETED" + + +class AtlasCommonParams: + qualified_name = 'qualifiedName' + guid = 'guid' + attributes = 'attributes' + relationships = 'relationshipAttributes' + uri = 'entityUri' + type_name = 'typeName' + unique_attributes = 'uniqueAttributes' + created_timestamp = 'createdTimestamp' + last_modified_timestamp = 'lastModifiedTimestamp' + + +class AtlasCommonTypes: + bookmark = 'Bookmark' + user = 'User' + reader = 'Reader' + cluster = 'Cluster' + application = 'Application' + data_set = 'DataSet' + + # These are just `virtual` types which do not actually exist in Atlas. + # We use those constant values to distinguish Atlas Python Client methods which should be used for populating + # such data. + # Tags are published using Glossary API, badges using Classification API. Other entities are published using regular + # Entity API. + tag = 'Tag' + badge = 'Badge' + resource_report = 'Report' + + +class AtlasTableTypes: + table = 'Table' + column = 'Column' + database = 'Database' + schema = 'Schema' + source = 'Source' + watermark = 'TablePartition' + process = 'LineageProcess' + + +class AtlasDashboardTypes: + metadata = 'Dashboard' + group = 'DashboardGroup' + query = 'DashboardQuery' + chart = 'DashboardChart' + execution = 'DashboardExecution' + + +class AtlasKey(abc.ABC): + """ + Class for unification of entity keys between Atlas and Amundsen ecosystems. + + Since Atlas can be populated both by tools from 'Atlas world' (like Apache Atlas Hive hook/bridge) and Amundsen + Databuilder (and each of the approach has a different way to render unique identifiers) we need such class + to serve as unification layer. + """ + + def __init__(self, raw_id: str, database: Optional[str] = None) -> None: + self._raw_identifier = raw_id + self._database = database + + @property + def is_qualified_name(self) -> bool: + """ + Property assessing whether raw_id is qualified name. + + :returns: - + """ + if self.atlas_qualified_name_regex.match(self._raw_identifier): + return True + else: + return False + + @property + def is_amundsen_key(self) -> bool: + """ + Property assessing whether raw_id is amundsen key. 
+ + :returns: - + """ + if self.amundsen_key_regex.match(self._raw_identifier): + return True + else: + return False + + def get_details(self) -> Dict[str, str]: + """ + Collect as many details from key (either qn or amundsen key) + + :returns: dictionary of entity properties derived from key + """ + if self.is_qualified_name: + return self._get_details_from_qualified_name() + elif self.is_amundsen_key: + return self._get_details_from_key() + else: + raise ValueError(f'Value is neither valid qualified name nor amundsen key: {self._raw_identifier}') + + def _get_details(self, pattern: Any) -> Dict[str, str]: + """ + Helper function collecting data from regex match + + :returns: dictionary of matched regex groups with their values + """ + try: + result = pattern.match(self._raw_identifier).groupdict() + + return result + except KeyError: + raise KeyError + + def _get_details_from_qualified_name(self) -> Dict[str, str]: + """ + Collect as many details from qualified name + + :returns: dictionary of entity properties derived from qualified name + """ + try: + return self._get_details(self.atlas_qualified_name_regex) + except KeyError: + raise ValueError(f'This is not valid qualified name: {self._raw_identifier}') + + def _get_details_from_key(self) -> Dict[str, str]: + """ + Collect as many details from amundsen key + + :returns: dictionary of entity properties derived from amundsen key + """ + try: + return self._get_details(self.amundsen_key_regex) + except KeyError: + raise ValueError(f'This is not valid qualified name: {self._raw_identifier}') + + @property + @abc.abstractmethod + def atlas_qualified_name_regex(self) -> Any: + """ + Regex for validating qualified name (and collecting details from qn parts) + + :returns: - + """ + pass + + @property + @abc.abstractmethod + def amundsen_key_regex(self) -> Any: + """ + Regex for validating amundsen key (and collecting details from key parts) + + :returns: - + """ + pass + + @property + @abc.abstractmethod + def qualified_name(self) -> str: + """ + Properly formatted qualified name + + :returns: - + """ + pass + + @property + @abc.abstractmethod + def amundsen_key(self) -> str: + """ + Properly formetted amundsen key + + :returns: - + """ + pass + + @property + def native_atlas_entity_types(self) -> Set[str]: + """ + Atlas can be populated using two approaches: + 1. Using Atlas-provided bridge/hook tools to ingest data in push manner (like Atlas Hive Hook) + 2. Using Amundsen-provided databuilder framework in pull manner + + Since Atlas-provided tools follow different approach for rendering qualified name than databuilder does, + to provide compatibility for both approaches we need to act differently depending whether the table entity + was loaded by Atlas-provided or Amundsen-provided tools. We distinguish them by entity type - in Atlas the + naming convention assumes '_table' suffix in entity name while Amundsen does not have such suffix. + + If the entity_type (database in Amundsen lingo) is one of the values from this property, we treat it like + it was provided by Atlas and follow Atlas qualified name convention. + + If the opposite is true - we treat it like it was provided by Amundsen Databuilder, use generic entity types + and follow Amundsen key name convention. 
+ """ + return {'hive_table'} + + @property + def entity_type(self) -> str: + if self.is_qualified_name: + return self._database or '' + else: + return self.get_details()['database'] \ + if self.get_details()['database'] in self.native_atlas_entity_types else 'Table' + + +class AtlasTableKey(AtlasKey): + @property + def atlas_qualified_name_regex(self) -> Any: + return re.compile(r'^(?P.*?)\.(?P.*)@(?P.*?)$', re.X) + + @property + def amundsen_key_regex(self) -> Any: + return re.compile(r'^(?P.*?)://(?P.*)\.(?P.*?)\/(?P
.*?)$', re.X) + + @property + def qualified_name(self) -> str: + if not self.is_qualified_name and self.get_details()['database'] in self.native_atlas_entity_types: + spec = self._get_details_from_key() + + schema = spec['schema'] + table = spec['table'] + cluster = spec['cluster'] + + return f'{schema}.{table}@{cluster}' + else: + return self._raw_identifier + + @property + def amundsen_key(self) -> str: + if self.is_qualified_name: + spec = self._get_details_from_qualified_name() + + schema = spec['schema'] + table = spec['table'] + cluster = spec['cluster'] + + return f'{self._database}://{cluster}.{schema}/{table}' + elif self.is_amundsen_key: + return self._raw_identifier + else: + raise ValueError(f'Value is neither qualified name nor amundsen key: {self._raw_identifier}') + + +class AtlasColumnKey(AtlasKey): + @property + def atlas_qualified_name_regex(self) -> Any: + return re.compile(r'^(?P.*?)\.(?P
.*?)\.(?P.*?)@(?P.*?)$', re.X) + + @property + def amundsen_key_regex(self) -> Any: + return re.compile(r'^(?P.*?)://(?P.*)\.(?P.*?)\/(?P
.*?)\/(?P.*)$', + re.X) + + @property + def qualified_name(self) -> str: + if self.is_amundsen_key: + spec = self._get_details_from_key() + + schema = spec['schema'] + table = spec['table'] + cluster = spec['cluster'] + column = spec['column'] + + return f'{schema}.{table}.{column}@{cluster}' + elif self.is_qualified_name: + return self._raw_identifier + else: + raise ValueError(f'Value is neither qualified name nor amundsen key: {self._raw_identifier}') + + @property + def amundsen_key(self) -> str: + if self.is_qualified_name: + spec = self._get_details_from_qualified_name() + + schema = spec['schema'] + table = spec['table'] + cluster = spec['cluster'] + column = spec['column'] + + source = self._database.replace('column', 'table') if self._database else '' + + return f'{source}://{cluster}.{schema}/{table}/{column}' + elif self.is_amundsen_key: + return self._raw_identifier + else: + raise ValueError(f'Value is neither qualified name nor amundsen key: {self._raw_identifier}') diff --git a/common/index.html b/common/index.html new file mode 100644 index 0000000000..4603919ad1 --- /dev/null +++ b/common/index.html @@ -0,0 +1,1475 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + Overview - Amundsen + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
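For context on how the Atlas key classes defined above are meant to be used, here is a small sketch (the literal keys are the same fixtures exercised by tests/unit/utils/test_atlas_utils.py below; the print calls are only illustrative):

from amundsen_common.utils.atlas import AtlasTableKey

# An Amundsen-style table key produced by databuilder...
key = AtlasTableKey('hive_table://gold.database_name/table_name')
print(key.is_amundsen_key)   # True
print(key.qualified_name)    # database_name.table_name@gold

# ...and the same table addressed by its Atlas qualified name.
# The database hint is needed because a qualified name does not carry it.
qn = AtlasTableKey('database_name.table_name@gold', database='hive_table')
print(qn.amundsen_key)       # hive_table://gold.database_name/table_name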
+ + + + + + + + + \ No newline at end of file diff --git a/common/requirements-dev.txt b/common/requirements-dev.txt new file mode 100644 index 0000000000..fbdaf8e2d7 --- /dev/null +++ b/common/requirements-dev.txt @@ -0,0 +1,22 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +# Common dependencies for code quality control (testing, linting, static checks, etc.) --------------------------------- + +flake8>=3.9.2 +flake8-tidy-imports>=4.3.0 +isort[colors]~=5.8.0 +mock>=4.0.3 +mypy>=1.9.0 +pytest>=6.2.4 +pytest-cov>=2.12.0 +pytest-env>=0.6.2 +pytest-mock>=3.6.1 +typed-ast>=1.4.3 +pyspark==3.0.1 +types-mock>=5.1.0.3 +types-protobuf>=4.24.0.4 +types-python-dateutil>=2.8.19.14 +types-pytz>=2023.3.1.1 +types-requests<2.31.0.7 +types-setuptools>=69.0.0.0 diff --git a/common/setup.cfg b/common/setup.cfg new file mode 100644 index 0000000000..abd550187a --- /dev/null +++ b/common/setup.cfg @@ -0,0 +1,58 @@ +[flake8] +format = pylint +exclude = .svc,CVS,.bzr,.hg,.git,__pycache__,venv,.venv +max-complexity = 10 +max-line-length = 120 + +# flake8-tidy-imports rules +banned-modules = + dateutil.parser = Use `ciso8601` instead + flask.ext.restful = Use `flask_restful` + flask.ext.script = Use `flask_script` + flask_restful.reqparse = Use `marshmallow` for request/response validation + haversine = Use `from fast_distance import haversine` + py.test = Use `pytest` + python-s3file = Use `boto` + +[pep8] +max-line-length = 79 + +[tool:pytest] +addopts = --cov=amundsen_common --cov-fail-under=70 --cov-report=term-missing:skip-covered --cov-report=xml --cov-report=html -vvv + +[coverage:run] +omit = */models/* +branch = True + +[coverage:xml] +output = build/coverage.xml + +[coverage:html] +directory = build/coverage_html + +[mypy] +check_untyped_defs = true +disallow_any_generics = true +disallow_incomplete_defs = true +disallow_untyped_defs = true +no_implicit_optional = true + +[semantic_release] +version_variable = "./setup.py:__version__" +upload_to_pypi = true +upload_to_release = true +commit_subject = New release for {version} +commit_message = Signed-off-by: github-actions +commit_author = github-actions + +[mypy-marshmallow.*] +ignore_missing_imports = true + +[mypy-marshmallow3_annotations.*] +ignore_missing_imports = true + +[mypy-setuptools.*] +ignore_missing_imports = true + +[mypy-tests.*] +disallow_untyped_defs = false diff --git a/common/setup.py b/common/setup.py new file mode 100644 index 0000000000..a864d57340 --- /dev/null +++ b/common/setup.py @@ -0,0 +1,56 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 +import os + +from setuptools import find_packages, setup + +__version__ = '0.32.0' + + +requirements_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'requirements-dev.txt') +with open(requirements_path) as requirements_file: + requirements_dev = requirements_file.readlines() + + +setup( + name='amundsen-common', + version=__version__, + description='Common code library for Amundsen', + long_description=open('README.md').read(), + long_description_content_type='text/markdown', + url='https://github.com/amundsen-io/amundsen/tree/main/common', + maintainer='Amundsen TSC', + maintainer_email='amundsen-tsc@lists.lfai.foundation', + packages=find_packages(exclude=['tests*']), + install_requires=[ + # Packages in here should rarely be pinned. This is because these + # packages (at the specified version) are required for project + # consuming this library. 
By pinning to a specific version you are the + # number of projects that can consume this or forcing them to + # upgrade/downgrade any dependencies pinned here in their project. + # + # Generally packages listed here are pinned to a major version range. + # + # e.g. + # Python FooBar package for foobaring + # pyfoobar>=1.0, <2.0 + # + # This will allow for any consuming projects to use this library as + # long as they have a version of pyfoobar equal to or greater than 1.x + # and less than 2.x installed. + 'Flask>=2.2.5', + 'attrs>=19.0.0', + 'marshmallow>=3.0', + 'marshmallow3-annotations>=1.1.0' + ], + extras_require={ + 'all': requirements_dev + }, + python_requires=">=3.8", + package_data={'amundsen_common': ['py.typed']}, + classifiers=[ + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + ], +) diff --git a/common/tests/__init__.py b/common/tests/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/common/tests/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/common/tests/tests/__init__.py b/common/tests/tests/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/common/tests/tests/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/common/tests/tests/test_fixtures.py b/common/tests/tests/test_fixtures.py new file mode 100644 index 0000000000..a35c6c62d6 --- /dev/null +++ b/common/tests/tests/test_fixtures.py @@ -0,0 +1,106 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from amundsen_common.tests.fixtures import (next_application, next_col_type, + next_columns, next_database, + next_description, + next_description_source, + next_descriptions, next_int, next_item, + next_range, next_string, next_table, + next_tag, next_tags, next_user) +from amundsen_common.models.table import Column, ProgrammaticDescription, Stat + + +class TestFixtures(unittest.TestCase): + # tests are numbered to ensure they execute in order + def test_00_next_int(self) -> None: + self.assertEqual(1000, next_int()) + + def test_01_next_string(self) -> None: + self.assertEqual('nopqrstuvw001011', next_string()) + + def test_02_next_string(self) -> None: + self.assertEqual('foo_yzabcdefgh001022', next_string(prefix='foo_')) + + def test_03_next_string(self) -> None: + self.assertEqual('jklm001027', next_string(length=4)) + + def test_04_next_string(self) -> None: + self.assertEqual('bar_opqr001032', next_string(prefix='bar_', length=4)) + + def test_05_next_range(self) -> None: + self.assertEqual(3, len(next_range())) + + def test_06_next_item(self) -> None: + self.assertEqual('c', next_item(items=['a', 'b', 'c'])) + + def test_07_next_database(self) -> None: + self.assertEqual('database2', next_database()) + + def test_08_next_application(self) -> None: + app = next_application() + self.assertEqual('Apwxyzabcd001044', app.name) + self.assertEqual('apwxyzabcd001044', app.id) + self.assertEqual('https://apwxyzabcd001044.example.com', app.application_url) + + def test_09_next_application(self) -> None: + app = next_application(application_id='foo') + self.assertEqual('Foo', app.name) + self.assertEqual('foo', app.id) + self.assertEqual('https://foo.example.com', app.application_url) + + def test_10_next_tag(self) -> None: + tag = next_tag() + 
self.assertEqual('tafghijklm001053', tag.tag_name) + self.assertEqual('default', tag.tag_type) + + def test_11_next_tags(self) -> None: + tags = next_tags() + self.assertEqual(4, len(tags)) + self.assertEqual(['tahijklmno001081', + 'tapqrstuvw001063', + 'taqrstuvwx001090', + 'tayzabcdef001072'], [tag.tag_name for tag in tags]) + + def test_12_next_description_source(self) -> None: + self.assertEqual('dezabcdefg001099', next_description_source()) + + def test_13_next_description(self) -> None: + self.assertEqual(ProgrammaticDescription(text='ijklmnopqrstuvwxyzab001120', source='dedefghijk001129'), + next_description()) + + def test_14_next_col_type(self) -> None: + self.assertEqual('varchar', next_col_type()) + + def test_15_just_execute_next_columns(self) -> None: + columns = next_columns(table_key='not_important') + self.assertEqual(1, len(columns)) + self.assertEqual([Column(name='coopqrstuv001140', key='not_important/coopqrstuv001140', + description='coopqrstuv001140 description', col_type='int', + sort_order=0, stats=[Stat(stat_type='num_rows', stat_val='114200', + start_epoch=None, end_epoch=None)]) + ], columns) + + def test_16_just_execute_next_descriptions(self) -> None: + descs = next_descriptions() + self.assertEqual(3, len(descs)) + self.assertEqual([ + ProgrammaticDescription(source='dedefghijk001233', text='ijklmnopqrstuvwxyzab001224'), + ProgrammaticDescription(source='devwxyzabc001173', text='abcdefghijklmnopqrst001164'), + ProgrammaticDescription(source='dezabcdefg001203', text='efghijklmnopqrstuvwx001194')], descs) + + def test_17_just_execute_next_table(self) -> None: + table = next_table() + self.assertEqual(2, len(table.columns)) + self.assertEqual('tbnopqrstu001243', table.name) + self.assertEqual('database1://clwxyzabcd001252.scfghijklm001261/tbnopqrstu001243', table.key) + + def test_18_next_user(self) -> None: + user = next_user() + self.assertEqual('Jklmno', user.last_name) + self.assertEqual('Bob', user.first_name) + self.assertEqual('usqrstuvwx001350', user.user_id) + self.assertEqual('usqrstuvwx001350@example.com', user.email) + self.assertEqual(True, user.is_active) diff --git a/common/tests/unit/__init__.py b/common/tests/unit/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/common/tests/unit/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/common/tests/unit/log/__init__.py b/common/tests/unit/log/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/common/tests/unit/log/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/common/tests/unit/log/test_action_log.py b/common/tests/unit/log/test_action_log.py new file mode 100644 index 0000000000..d020da12a7 --- /dev/null +++ b/common/tests/unit/log/test_action_log.py @@ -0,0 +1,83 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import socket +import unittest +from contextlib import contextmanager +from typing import Generator, Any + +import flask + +from amundsen_common.log import action_log, action_log_callback +from amundsen_common.log.action_log import action_logging, get_epoch_millisec + +app = flask.Flask(__name__) + + +class ActionLogTest(unittest.TestCase): + + def test_metrics_build(self) -> None: + # with patch.object(current_app, 'config'): + with app.test_request_context(): + func_name = 'search' + metrics = action_log._build_metrics(func_name, 'dummy', 777, foo='bar') + + expected = { + 'command': 'search', + 'host_name': socket.gethostname(), + 'pos_args_json': '["dummy", 777]', + 'keyword_args_json': '{"foo": "bar"}', + 'user': 'UNKNOWN', + } + + for k, v in expected.items(): + self.assertEquals(v, metrics.get(k)) + + self.assertTrue(metrics.get('start_epoch_ms') <= get_epoch_millisec()) # type: ignore + + def test_fail_function(self) -> None: + """ + Actual function is failing and fail needs to be propagated. + :return: + """ + with app.test_request_context(), self.assertRaises(NotImplementedError): + fail_func() + + def test_success_function(self) -> None: + """ + Test success function but with failing callback. + In this case, failure should not propagate. + :return: + """ + with app.test_request_context(), fail_action_logger_callback(): + success_func() + + +@contextmanager +def fail_action_logger_callback() -> Generator[Any, Any, Any]: + """ + Adding failing callback and revert it back when closed. + :return: + """ + tmp = action_log_callback.__pre_exec_callbacks[:] + + def fail_callback(_action_callback: Any) -> None: + raise NotImplementedError + + action_log_callback.register_pre_exec_callback(fail_callback) + yield + action_log_callback.__pre_exec_callbacks = tmp + + +@action_logging +def fail_func() -> None: + raise NotImplementedError + + +@action_logging +def success_func() -> None: + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/common/tests/unit/log/test_http_header_caller_retrieval.py b/common/tests/unit/log/test_http_header_caller_retrieval.py new file mode 100644 index 0000000000..15ed7b1b4a --- /dev/null +++ b/common/tests/unit/log/test_http_header_caller_retrieval.py @@ -0,0 +1,26 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import flask +from mock import patch +from mock import MagicMock + +from amundsen_common.log import http_header_caller_retrieval +from amundsen_common.log.http_header_caller_retrieval import HttpHeaderCallerRetrieval + +app = flask.Flask(__name__) + + +class ActionLogTest(unittest.TestCase): + def test(self) -> None: + with app.test_request_context(), \ + patch.object(http_header_caller_retrieval, 'request', new=MagicMock()) as mock_request: + mock_request.headers.get.return_value = 'foo' + actual = HttpHeaderCallerRetrieval().get_caller() + self.assertEqual(actual, 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/common/tests/unit/models/__init__.py b/common/tests/unit/models/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/common/tests/unit/models/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 diff --git a/common/tests/unit/models/test_user.py b/common/tests/unit/models/test_user.py new file mode 100644 index 0000000000..4e09be9bc0 --- /dev/null +++ b/common/tests/unit/models/test_user.py @@ -0,0 +1,100 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import flask +import unittest + +from marshmallow import ValidationError + +from amundsen_common.models.user import UserSchema + +app = flask.Flask(__name__) + + +class UserTest(unittest.TestCase): + def test_set_user_id_from_email(self) -> None: + """ + Deserialization and serialization sets user_id from email if no user_id + :return: + """ + with app.test_request_context(): + self.assertEqual( + UserSchema().load({"email": "test@test.com"}).user_id, "test@test.com" + ) + + def test_set_display_name_from_full_name(self) -> None: + """ + Deserialization and serialization sets display_name from full_name if no display_name and + full_name is a non-empty string + :return: + """ + test_user = {"email": "test@test.com", "full_name": "Test User"} + with app.test_request_context(): + self.assertEqual(UserSchema().load(test_user).display_name, "Test User") + + def test_set_display_name_from_email(self) -> None: + """ + Deserialization and serialization sets display_name from email if no display_name and + full_name is None + :return: + """ + with app.test_request_context(): + self.assertEqual( + UserSchema().load({"email": "test@test.com"}).display_name, + "test@test.com", + ) + + def test_set_display_name_from_email_if_full_name_empty(self) -> None: + """ + Deserialization and serialization sets display_name from email if no display_name and + full_name is '' + :return: + """ + test_user = {"email": "test@test.com", "full_name": ""} + with app.test_request_context(): + self.assertEqual(UserSchema().load(test_user).display_name, "test@test.com") + + def test_profile_url(self) -> None: + """ + Deserialization and serialization sets profile_url from function defined at GET_PROFILE_URL + if no profile_url provided' + :return: + """ + test_user = {"email": "test@test.com", "GET_PROFILE_URL": lambda _: "testUrl"} + + with app.test_request_context(): + self.assertEqual(UserSchema().load(test_user).profile_url, "testUrl") + + def test_raise_error_if_no_display_name(self) -> None: + """ + Error is raised if deserialization of Dict will not generate a display_name + :return: + """ + with app.test_request_context(): + with self.assertRaises(ValidationError): + UserSchema().load({}) + + def test_raise_error_if_no_user_id(self) -> None: + """ + Error is raised if deserialization of Dict will not generate a user_id + :return: + """ + with app.test_request_context(): + with self.assertRaises(ValidationError): + UserSchema().load({"display_name": "Test User"}) + + def test_str_no_value(self) -> None: + """ + Test _str_no_value returns True for a string of spaces + :return: + """ + self.assertEqual(UserSchema()._str_no_value(" "), True) + + def test_extra_key_does_not_raise(self) -> None: + """ + Handle extra keys in the user data + :return: + """ + test_user = {"email": "test@test.com", "foo": "bar"} + with app.test_request_context(): + self.assertEqual(UserSchema().load(test_user).email, "test@test.com") diff --git a/common/tests/unit/utils/__init__.py b/common/tests/unit/utils/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/common/tests/unit/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 diff --git a/common/tests/unit/utils/test_atlas_utils.py b/common/tests/unit/utils/test_atlas_utils.py new file mode 100644 index 0000000000..2d665acf59 --- /dev/null +++ b/common/tests/unit/utils/test_atlas_utils.py @@ -0,0 +1,235 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 +import unittest + +from amundsen_common.utils.atlas import AtlasColumnKey, AtlasTableKey + + +class TestAtlasTableKey(unittest.TestCase): + def test_table_key(self) -> None: + params = [ + ('hive_table://gold.database_name/table_name', + None, + 'hive_table://gold.database_name/table_name', + 'database_name.table_name@gold', + dict(database='hive_table', cluster='gold', schema='database_name', table='table_name'), + False, + True), + ('database_name.table_name@gold', + 'hive_table', + 'hive_table://gold.database_name/table_name', + 'database_name.table_name@gold', + dict(cluster='gold', schema='database_name', table='table_name'), + True, + False) + ] + + for key, database, amundsen_key, qualified_name, details, is_key_qualified_name, is_key_amundsen_key in params: + with self.subTest(): + result = AtlasTableKey(key, database=database) + + self.assertEqual(result.amundsen_key, amundsen_key) + self.assertEqual(result.qualified_name, qualified_name) + self.assertEqual(result.is_qualified_name, is_key_qualified_name) + self.assertEqual(result.is_amundsen_key, is_key_amundsen_key) + self.assertDictEqual(result.get_details(), details) + + def test_table_key_amundsen_key_validation(self) -> None: + params = [ + ('hive://cluster_name.db_name/table_name', True), + ('hive_table://cluster_name.with.dot.db_name/table_name', True), + ('db_name.table_name@cluster_name', False) + ] + + for key, is_amundsen_key in params: + with self.subTest(f'Amundsen key validation for key: {key}'): + result = AtlasTableKey(key) + + self.assertEqual(is_amundsen_key, result.is_amundsen_key) + + def test_table_key_qualified_name_validation(self) -> None: + params = [ + ('hive://cluster_name.db_name/table_name', False), + ('hive_table://cluster_name.db_name/table_name', False), + ('db_name.table_name@cluster_name', True), + ('db.table@cluster.dot', True) + ] + + for key, is_amundsen_key in params: + with self.subTest(f'Amundsen qualified name validation for key: {key}'): + result = AtlasTableKey(key) + + self.assertEqual(is_amundsen_key, result.is_qualified_name) + + def test_table_key_qualified_name_from_amundsen_key(self) -> None: + params = [ + ('hive_table://cluster_name.db_name/table_name', 'db_name.table_name@cluster_name'), + ('hive://cluster_name.db_name/table_name', 'hive://cluster_name.db_name/table_name') + ] + + for key, qn in params: + with self.subTest(f'Test rendering qualified name from amundsen key: {key}'): + result = AtlasTableKey(key) + + self.assertEqual(qn, result.qualified_name) + + def test_table_key_amundsen_key_from_qualified_name(self) -> None: + params = [ + ('db_name.table_name@cluster_name', 'hive', 'hive://cluster_name.db_name/table_name'), + ('db_name.table_name@cluster_name.dot', 'hive_table', 'hive_table://cluster_name.dot.db_name/table_name') + ] + + for qn, database, key in params: + with self.subTest(f'Test rendering amundsen key from qualified name: {qn}'): + result = AtlasTableKey(qn, database=database) + + self.assertEqual(key, result.amundsen_key) + + def test_table_key_details_from_amundsen_key(self) -> None: + params = [ + ('hive://cluster_name.db_name/table_name', + dict(database='hive', cluster='cluster_name', 
schema='db_name', table='table_name')), + ('hive_table://cluster_name.dot.db_name/table_name', + dict(database='hive_table', cluster='cluster_name.dot', schema='db_name', table='table_name')) + ] + + for key, details in params: + with self.subTest(f'Test extract details from amundsen key: {key}'): + result = AtlasTableKey(key) + + self.assertEqual(details, result.get_details()) + + def test_table_key_details_from_qualified_name(self) -> None: + params = [ + ('db_name.table_name@cluster_name', + dict(cluster='cluster_name', schema='db_name', table='table_name')), + ('db_name.table_name@cluster_name.dot', + dict(cluster='cluster_name.dot', schema='db_name', table='table_name')) + ] + + for qn, details in params: + with self.subTest(f'Test extract details from qualified name: {qn}'): + result = AtlasTableKey(qn) + + self.assertEqual(details, result.get_details()) + + +class TestAtlasColumnKey(unittest.TestCase): + def test_table_column_key(self) -> None: + params = [ + ('hive_table://gold.database_name/table_name/column_name', + None, + 'hive_table://gold.database_name/table_name/column_name', + 'database_name.table_name.column_name@gold', + dict(database='hive_table', cluster='gold', schema='database_name', table='table_name', + column='column_name'), + False, + True), + ('database_name.table_name.column_name@gold', + 'hive_table', + 'hive_table://gold.database_name/table_name/column_name', + 'database_name.table_name.column_name@gold', + dict(cluster='gold', schema='database_name', table='table_name', column='column_name'), + True, + False) + ] + + for key, database, amundsen_key, qualified_name, details, is_key_qualified_name, is_key_amundsen_key in params: + with self.subTest(): + result = AtlasColumnKey(key, database=database) + + self.assertEqual(result.amundsen_key, amundsen_key) + self.assertEqual(result.qualified_name, qualified_name) + self.assertEqual(result.is_qualified_name, is_key_qualified_name) + self.assertEqual(result.is_amundsen_key, is_key_amundsen_key) + self.assertDictEqual(result.get_details(), details) + + def test_table_column_key_amundsen_key_validation(self) -> None: + params = [ + ('hive://cluster_name.db_name/table_name/column_name', True), + ('hive_table://cluster_name.with.dot.db_name/table_name/column_name', True), + ('db_name.table_name.column_name@cluster_name', False), + ('db.table.column@cluster.dot', False) + ] + + for key, is_amundsen_key in params: + with self.subTest(f'Amundsen key validation for key: {key}'): + result = AtlasColumnKey(key) + + self.assertEqual(is_amundsen_key, result.is_amundsen_key) + + def test_table_column_key_qualified_name_validation(self) -> None: + params = [ + ('hive://cluster_name.db_name/table_name/column_name', False), + ('hive_table://cluster_name.with.dot.db_name/table_name/column_name', False), + ('db_name.table_name.column_name@cluster_name', True), + ('db.table.column@cluster.dot', True) + ] + + for key, is_amundsen_key in params: + with self.subTest(f'Amundsen qualified name validation for key: {key}'): + result = AtlasColumnKey(key) + + self.assertEqual(is_amundsen_key, result.is_qualified_name) + + def test_table_column_key_qualified_name_from_amundsen_key(self) -> None: + params = [ + ('hive://cluster_name.db_name/table_name/column_name', + 'db_name.table_name.column_name@cluster_name'), + ('hive_table://cluster_name.dot.db_name/table_name/column_name', + 'db_name.table_name.column_name@cluster_name.dot') + ] + + for key, qn in params: + with self.subTest(f'Test rendering qualified name from amundsen key: {key}'): + 
result = AtlasColumnKey(key) + + self.assertEqual(qn, result.qualified_name) + + def test_table_column_key_amundsen_key_from_qualified_name(self) -> None: + params = [ + ('db_name.table_name.column_name@cluster_name', 'hive', + 'hive://cluster_name.db_name/table_name/column_name'), + ('db_name.table_name.column_name.dot@cluster_name.dot', 'hive_table', + 'hive_table://cluster_name.dot.db_name/table_name/column_name.dot') + ] + + for qn, database, key in params: + with self.subTest(f'Test rendering amundsen key from qualified name: {qn}'): + result = AtlasColumnKey(qn, database=database) + + self.assertEqual(key, result.amundsen_key) + + def test_table_column_key_details_from_amundsen_key(self) -> None: + params = [ + ('hive://cluster_name.db_name/table_name/column_name', + dict(database='hive', cluster='cluster_name', schema='db_name', table='table_name', + column='column_name')), + ('hive_table://cluster_name.dot.db_name/table_name/column_name.dot', + dict(database='hive_table', cluster='cluster_name.dot', schema='db_name', table='table_name', + column='column_name.dot')) + ] + + for key, details in params: + with self.subTest(f'Test extract details from amundsen key: {key}'): + result = AtlasColumnKey(key) + + self.assertEqual(details, result.get_details()) + + def test_table_column_key_details_from_qualified_name(self) -> None: + params = [ + ('db_name.table_name.column_name@cluster_name', + dict(cluster='cluster_name', schema='db_name', table='table_name', column='column_name')), + ('db_name.table_name.column_name.dot@cluster_name.dot', + dict(cluster='cluster_name.dot', schema='db_name', table='table_name', column='column_name.dot')) + ] + + for qn, details in params: + with self.subTest(f'Test extract details from qualified name: {qn}'): + result = AtlasColumnKey(qn) + + self.assertEqual(details, result.get_details()) + + +if __name__ == '__main__': + unittest.main() diff --git a/css/app.css b/css/app.css new file mode 100644 index 0000000000..eb101548a9 --- /dev/null +++ b/css/app.css @@ -0,0 +1,11 @@ +@import "theme.css"; + +/* Splits a long line descriptions in tables in to multiple lines */ +.wy-table-responsive table td, .wy-table-responsive table th { + white-space: normal !important; +} + +/* align multi line csv table columns */ +table.docutils div.line-block { + margin-left: 0; +} diff --git a/databuilder/CHANGELOG/index.html b/databuilder/CHANGELOG/index.html new file mode 100644 index 0000000000..ee6002a08e --- /dev/null +++ b/databuilder/CHANGELOG/index.html @@ -0,0 +1,1542 @@ + + + + + + + + + + + + + + + + + + + + + + + + CHANGELOG - Amundsen + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

4.4.0

Features

Bugfixes

Pre 4.4.0 changes

Feature

  • Add support for tags based on atlas terms (#466) (cc1caf3)
  • Make DescriptionMetadata inherit from GraphSerializable (#461) (7f095fb)
  • Add TableSerializable and mysql_serializer (#459) (4bb4452)
  • Neptune Data builder Integration (#438) (303e8aa)
  • Add config key for connect_arg for SqlAlchemyExtractor (#434) (7f3be0f)
  • Vertica metadata extractor (#433) (f4bd207)
  • Multi-yield transformers (#396) (49ae0ed)
  • Atlas_search_extractor | :tada: Initial commit. (#415) (8c63307)
  • Sample Feast job with ES publisher (#425) (453a18b)
  • Adding CsvTableBadgeExtractor (#417) (592ee71)
  • Feast extractor (#414) (2343a90)
  • Adding first pass of delta lake metadata extractor as well as a sample script on how it would be used. (#351) (e8679aa)
  • Use parameters to allow special characters in neo4j cypher statement (#382) (6fd5035)
  • Column level badges cont. (#381) (af4b512)
  • Support dashboard chart in search (#383) (6cced36)
  • Column level badges (#375) (8beee3e)
  • Added Dremio extractor (#377) (63f239f)
  • Add an extractor for pulling user information from BambooHR (#369) (6802ab1)
  • Add sample_glue_loader script (#366) (fa3f11b)
  • Parameterize Snowflake Schema in Snowflake Metadata Extractor (#361) (aa4416c)
  • Mode Batch dashboard charrt API (#362) (87213c5)
  • Create a RedshiftMetadataExtractor that supports late binding views (#356) (4113cfd)
  • Add MySQL sample data loader (#359) (871a176)
  • Add Snowflake table last updated timestamp extractor (#348) (0bac11b)
  • Add Tableau dashboard metadata extractors (#333) (46207ee)
  • Add github actions for databuilder (#336) (236e7de)
  • Allow hive sql to be provided as config (#312) (8075a6c)
  • Enhance glue extractor (#306) (faa795c)
  • Add RedashDashboardExtractor for extracting dashboards from redash.io (#300) (f1b0dfa)
  • Add a transformer that adds tags to all tables created in a job (#287) (d2f4bd3)

Fix

  • Add support for Tableau multi-site deployment (#463) (e35af58)
  • Avoid error by checking for existence before close. (#454) (5cd0dc8)
  • Correct config getter (#455) (4b37746)
  • Close SQL Alchemy connections. (#453) (25124c1)
  • Add comma between bigquery requirements listings (#452) (027edb9)
  • Increase the compatibility of id structure between the Databuilder and the Metadata Library (#445) (6a13762)
  • Move ‘grouped_tables’ into _retrieve_tables (#430) (26a0d0a)
  • Address PyAthena version (#429) (7157c24)
  • Add csv badges back in Quickstart (#418) (c0296b7)
  • Typo in Readme (#424) (29bd72f)
  • Fix redash dashboard exporter (#422) (fa626f5)
  • Update the key format of set ‘grouped_tables’ (#421) (4c9e5f7)
  • Retry loop for exception caused by deadlock on badge node (#404) (9fd1513)
  • FsNeo4jCSVLoader fails if nodes have disjoint keys (#408) (c07cec9)
  • Cast dashboard usage to be int (#412) (8bcc489)
  • Pandas ‘nan’ values (#409) (3a28f46)
  • Add databuilder missing dependencies (#400) (6718396)
  • Allow BigQuery Usage Extractor to extract usage for views (#399) (8779229)
  • Hive metadata extractor not work on postgresql (#394) (2992618)
  • Issues with inconsistency in case conversion (#388) (9595866)
  • Update elasticsearch table index mapping (#373) (88c0552)
  • Fix programmatic source data (#367) (4f5df39)
  • Update connection string in Snowflake extractor to include wareh… (#357) (a11d206)
  • Edge case in Snowflake information_schema.last_altered value (#360) (c3e713e)
  • Correct typo in Snowflake Last Updated extract query (#358) (5c2e98e)
  • Set Tableau URLs (base + API) via config (#349) (1baec33)
  • Fix invalid timestamp handling in dashboard transformer (#339) (030ef49)
  • Update postgres_sample_dag to set table extract job as upstream for elastic search publisher (#340) (c79935e)
  • deps: Unpin attrs (#332) (86f658d)
  • Cypher statement param issue in Neo4jStalenessRemovalTask (#307) (0078761)
  • Added missing job tag key in hive_sample_dag.py (#308) (d6714b7)
  • Fix sql for missing columns and mysql based dialects (#550) (#305) (4b7b147)
  • Escape backslashes in Neo4jCsvPublisher (1faa713)
  • Variable organization in Model URL (#293) (b4c24ef)

Documentation

+ + + + + + + + + \ No newline at end of file diff --git a/databuilder/LICENSE b/databuilder/LICENSE new file mode 100644 index 0000000000..a1c70dc855 --- /dev/null +++ b/databuilder/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2018 Lyft, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/databuilder/MANIFEST.in b/databuilder/MANIFEST.in new file mode 100644 index 0000000000..db84e7298b --- /dev/null +++ b/databuilder/MANIFEST.in @@ -0,0 +1,5 @@ +include requirements.txt + +global-include requirements-dev.txt + +recursive-include databuilder/types/atlas/schema * diff --git a/databuilder/Makefile b/databuilder/Makefile new file mode 100644 index 0000000000..2b11aa226a --- /dev/null +++ b/databuilder/Makefile @@ -0,0 +1,30 @@ +clean: + find . -name \*.pyc -delete + find . -name __pycache__ -delete + rm -rf dist/ + +.PHONY: test_unit +test_unit: + python3 -bb -m pytest tests + +lint: + flake8 . + +.PHONY: mypy +mypy: + mypy . + +.PHONY: isort +isort: + isort . + +.PHONY: isort_check +isort_check: + isort ./ --check --diff + +.PHONY: test +test: test_unit lint mypy isort_check + +.PHONY: install_deps +install_deps: + pip3 install -e ".[all]" diff --git a/databuilder/NOTICE b/databuilder/NOTICE new file mode 100644 index 0000000000..395e2ff80e --- /dev/null +++ b/databuilder/NOTICE @@ -0,0 +1,4 @@ +amundsendatabuilder +Copyright 2018-2019 Lyft Inc. + +This product includes software developed at Lyft Inc. diff --git a/databuilder/databuilder/__init__.py b/databuilder/databuilder/__init__.py new file mode 100644 index 0000000000..6f1d8f2fb4 --- /dev/null +++ b/databuilder/databuilder/__init__.py @@ -0,0 +1,75 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import abc + +from pyhocon import ConfigFactory, ConfigTree + + +class Scoped(object, metaclass=abc.ABCMeta): + _EMPTY_CONFIG = ConfigFactory.from_dict({}) + """ + An interface for class that works with scoped (nested) config. + https://github.com/chimpler/pyhocon + A scoped instance will use config within its scope. This is a way to + distribute configuration to its implementation instead of having it in + one central place. + This is very useful for DataBuilder as it has different components + (extractor, transformer, loader, publisher) and its component itself + could have different implementation. + For example these can be a configuration for two different extractors + "extractor.mysql.url" for MySQLExtractor + "extractor.filesystem.source_path" for FileSystemExtractor + + For MySQLExtractor, if you defined scope as "extractor.mysql", scoped + config will basically reduce it to the config that is only for MySQL. + config.get("extractor.mysql") provides you all the config within + 'extractor.mysql'. By removing outer context from the config, + MySQLExtractor is highly reusable. + """ + + @abc.abstractmethod + def init(self, conf: ConfigTree) -> None: + """ + All scoped instance is expected to be lazily initialized. Means that + __init__ should not have any heavy operation such as service call. + The reason behind is that Databuilder is a code at the same time, + code itself is used as a configuration. For example, you can + instantiate scoped instance with all the parameters already set, + ready to run, and actual execution will be executing init() and + execute. 
+ + :param conf: Typesafe config instance + :return: None + """ + pass + + @abc.abstractmethod + def get_scope(self) -> str: + """ + A scope for the config. Typesafe config supports nested config. + Scope, string, is used to basically peel off nested config + :return: + """ + return '' + + def close(self) -> None: + """ + Anything that needs to be cleaned up after the use of the instance. + :return: None + """ + pass + + @classmethod + def get_scoped_conf(cls, conf: ConfigTree, scope: str) -> ConfigTree: + """ + Convenient method to provide scoped method. + + :param conf: Type safe config instance + :param scope: scope string + :return: Type safe config instance + """ + if not scope: + return Scoped._EMPTY_CONFIG + + return conf.get(scope, Scoped._EMPTY_CONFIG) diff --git a/databuilder/databuilder/callback/__init__.py b/databuilder/databuilder/callback/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/callback/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/callback/call_back.py b/databuilder/databuilder/callback/call_back.py new file mode 100644 index 0000000000..905c70e668 --- /dev/null +++ b/databuilder/databuilder/callback/call_back.py @@ -0,0 +1,62 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import abc +import logging +from typing import List, Optional + +LOGGER = logging.getLogger(__name__) + + +class Callback(object, metaclass=abc.ABCMeta): + """ + A callback interface that expected to fire "on_success" if the operation is successful, else "on_failure" if + operation failed. + """ + + @abc.abstractmethod + def on_success(self) -> None: + """ + A call back method that will be called when operation is successful + :return: None + """ + pass + + @abc.abstractmethod + def on_failure(self) -> None: + """ + A call back method that will be called when operation failed + :return: None + """ + pass + + +def notify_callbacks(callbacks: List[Callback], is_success: bool) -> None: + """ + A Utility method that notifies callback. If any callback fails it will still go through all the callbacks, + and raise the last exception it experienced. + + :param callbacks: + :param is_success: + :return: + """ + + if not callbacks: + LOGGER.info('No callbacks to notify') + return + + LOGGER.info('Notifying callbacks') + + last_exception: Optional[Exception] = None + for callback in callbacks: + try: + if is_success: + callback.on_success() + else: + callback.on_failure() + except Exception as e: + LOGGER.exception('Failed while notifying callback') + last_exception = e + + if last_exception: + raise last_exception diff --git a/databuilder/databuilder/clients/__init__.py b/databuilder/databuilder/clients/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/clients/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/clients/neptune_client.py b/databuilder/databuilder/clients/neptune_client.py new file mode 100644 index 0000000000..44098d16a8 --- /dev/null +++ b/databuilder/databuilder/clients/neptune_client.py @@ -0,0 +1,141 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + Any, Callable, Dict, List, Optional, Tuple, +) + +from amundsen_gremlin.neptune_bulk_loader import api as neptune_bulk_loader_api +from boto3.session import Session +from gremlin_python.process.graph_traversal import ( + GraphTraversal, GraphTraversalSource, __, +) +from gremlin_python.process.traversal import T +from pyhocon import ConfigFactory, ConfigTree + +from databuilder import Scoped + + +class NeptuneSessionClient(Scoped): + """ + A convenience client for neptune gives functions to perform upserts, deletions and queries with filters. + """ + # What property is used to local nodes and edges by ids + NEPTUNE_HOST_NAME = 'neptune_host_name' + # AWS Region the Neptune cluster is located + AWS_REGION = 'aws_region' + AWS_ACCESS_KEY = 'aws_access_key' + AWS_SECRET_ACCESS_KEY = 'aws_access_secret' + AWS_SESSION_TOKEN = 'aws_session_token' + + WEBSOCKET_OPTIONS = 'websocket_options' + + DEFAULT_CONFIG = ConfigFactory.from_dict( + { + AWS_SESSION_TOKEN: None, + WEBSOCKET_OPTIONS: {}, + } + ) + + def __init__(self) -> None: + self._graph = None + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(NeptuneSessionClient.DEFAULT_CONFIG) + + boto_session = Session( + aws_access_key_id=conf.get_string(NeptuneSessionClient.AWS_ACCESS_KEY, default=None), + aws_secret_access_key=conf.get_string(NeptuneSessionClient.AWS_SECRET_ACCESS_KEY, default=None), + aws_session_token=conf.get_string(NeptuneSessionClient.AWS_SESSION_TOKEN, default=None), + region_name=conf.get_string(NeptuneSessionClient.AWS_REGION, default=None) + ) + self._neptune_host = conf.get_string(NeptuneSessionClient.NEPTUNE_HOST_NAME) + neptune_uri = "wss://{host}/gremlin".format( + host=self._neptune_host + ) + source_factory = neptune_bulk_loader_api.get_neptune_graph_traversal_source_factory( + neptune_url=neptune_uri, + session=boto_session + ) + self._graph = source_factory() + + def get_scope(self) -> str: + return 'neptune.client' + + def get_graph(self) -> GraphTraversalSource: + return self._graph + + def upsert_node(self, node_id: str, node_label: str, node_properties: Dict[str, Any]) -> None: + create_traversal = __.addV(node_label).property(T.id, node_id) + node_traversal = self.get_graph().V().has(T.id, node_id). \ + fold().coalesce(__.unfold(), create_traversal) + + node_traversal = NeptuneSessionClient.update_entity_properties_on_traversal(node_traversal, node_properties) + node_traversal.next() + + def upsert_edge( + self, + start_node_id: str, + end_node_id: str, + edge_id: str, + edge_label: str, + edge_properties: Dict[str, Any] + ) -> None: + create_traversal = __.V().has( + T.id, start_node_id + ).addE(edge_label).to(__.V().has(T.id, end_node_id)).property(T.id, edge_id) + edge_traversal = self.get_graph().V().has(T.id, start_node_id).outE(edge_label).has(T.id, edge_id). \ + fold(). 
\ + coalesce(__.unfold(), create_traversal) + + edge_traversal = NeptuneSessionClient.update_entity_properties_on_traversal(edge_traversal, edge_properties) + edge_traversal.next() + + @staticmethod + def update_entity_properties_on_traversal( + graph_traversal: GraphTraversal, + properties: Dict[str, Any] + ) -> GraphTraversal: + for key, value in properties.items(): + key_split = key.split(':') + key = key_split[0] + value_type = key_split[1] + if "Long" in value_type: + value = int(value) + graph_traversal = graph_traversal.property(key, value) + + return graph_traversal + + @staticmethod + def filter_traversal( + graph_traversal: GraphTraversal, + filter_properties: List[Tuple[str, Any, Callable]], + ) -> GraphTraversal: + for filter_property in filter_properties: + (filter_property_name, filter_property_value, filter_operator) = filter_property + graph_traversal = graph_traversal.has(filter_property_name, filter_operator(filter_property_value)) + return graph_traversal + + def delete_edges( + self, + filter_properties: List[Tuple[str, Any, Callable]], + edge_labels: Optional[List[str]] + ) -> None: + tx = self.get_graph().E() + if edge_labels: + tx = tx.hasLabel(*edge_labels) + tx = NeptuneSessionClient.filter_traversal(tx, filter_properties) + + tx.drop().iterate() + + def delete_nodes( + self, + filter_properties: List[Tuple[str, Any, Callable]], + node_labels: Optional[List[str]] + ) -> None: + tx = self.get_graph().V() + if node_labels: + tx = tx.hasLabel(*node_labels) + tx = NeptuneSessionClient.filter_traversal(tx, filter_properties) + + tx.drop().iterate() diff --git a/databuilder/databuilder/extractor/__init__.py b/databuilder/databuilder/extractor/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/extractor/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/extractor/athena_metadata_extractor.py b/databuilder/databuilder/extractor/athena_metadata_extractor.py new file mode 100644 index 0000000000..da26491924 --- /dev/null +++ b/databuilder/databuilder/extractor/athena_metadata_extractor.py @@ -0,0 +1,116 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections import namedtuple +from itertools import groupby +from typing import ( + Any, Dict, Iterator, Union, +) + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor import sql_alchemy_extractor +from databuilder.extractor.base_extractor import Extractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +TableKey = namedtuple('TableKey', ['schema', 'table_name']) + +LOGGER = logging.getLogger(__name__) + + +class AthenaMetadataExtractor(Extractor): + """ + Extracts Athena table and column metadata from underlying meta store database using SQLAlchemyExtractor + """ + + SQL_STATEMENT = """ + SELECT + {catalog_source} as cluster, table_schema as schema, table_name as name, column_name as col_name, + data_type as col_type,ordinal_position as col_sort_order, + comment as col_description, extra_info as extras from information_schema.columns + {where_clause_suffix} + ORDER by cluster, schema, name, col_sort_order ; + """ + + # CONFIG KEYS + WHERE_CLAUSE_SUFFIX_KEY = 'where_clause_suffix' + CATALOG_KEY = 'catalog_source' + + # Default values + DEFAULT_CLUSTER_NAME = 'master' + + DEFAULT_CONFIG = ConfigFactory.from_dict( + {WHERE_CLAUSE_SUFFIX_KEY: ' ', CATALOG_KEY: DEFAULT_CLUSTER_NAME} + ) + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(AthenaMetadataExtractor.DEFAULT_CONFIG) + self._cluster = conf.get_string(AthenaMetadataExtractor.CATALOG_KEY) + + self.sql_stmt = AthenaMetadataExtractor.SQL_STATEMENT.format( + where_clause_suffix=conf.get_string(AthenaMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY), + catalog_source=self._cluster + ) + + LOGGER.info('SQL for Athena metadata: %s', self.sql_stmt) + + self._alchemy_extractor = sql_alchemy_extractor.from_surrounding_config(conf, self.sql_stmt) + self._extract_iter: Union[None, Iterator] = None + + def close(self) -> None: + if getattr(self, '_alchemy_extractor', None) is not None: + self._alchemy_extractor.close() + + def extract(self) -> Union[TableMetadata, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.athena_metadata' + + def _get_extract_iter(self) -> Iterator[TableMetadata]: + """ + Using itertools.groupby and raw level iterator, it groups to table and yields TableMetadata + :return: + """ + + for key, group in groupby(self._get_raw_extract_iter(), self._get_table_key): + columns = [] + + for row in group: + last_row = row + columns.append(ColumnMetadata(row['col_name'], + row['extras'] if row['extras'] is not None else row['col_description'], + row['col_type'], row['col_sort_order'])) + + yield TableMetadata('athena', last_row['cluster'], + last_row['schema'], + last_row['name'], + '', + columns) + + def _get_raw_extract_iter(self) -> Iterator[Dict[str, Any]]: + """ + Provides iterator of result row from SQLAlchemy extractor + :return: + """ + row = self._alchemy_extractor.extract() + while row: + yield row + row = self._alchemy_extractor.extract() + + def _get_table_key(self, row: Dict[str, Any]) -> Union[TableKey, None]: + """ + Table key consists of schema and table name + :param row: + :return: + """ + if row: + return TableKey(schema=row['schema'], table_name=row['name']) + + return None diff --git a/databuilder/databuilder/extractor/atlas_search_data_extractor.py b/databuilder/databuilder/extractor/atlas_search_data_extractor.py 
new file mode 100644 index 0000000000..5c8f30399e --- /dev/null +++ b/databuilder/databuilder/extractor/atlas_search_data_extractor.py @@ -0,0 +1,429 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import importlib +import logging +import multiprocessing.pool +from copy import deepcopy +from functools import reduce +from typing import ( + Any, Dict, Generator, Iterator, List, Optional, Tuple, +) + +from amundsen_common.utils.atlas import AtlasTableKey +from atlasclient.client import Atlas +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor.base_extractor import Extractor + +LOGGER = logging.getLogger(__name__) + +# custom types +type_fields_mapping_spec = Dict[str, List[Tuple[str, Any, Any, Any]]] +type_fields_mapping = List[Tuple[str, Any, Any, Any]] + +# @todo document classes/methods +# @todo write tests + +__all__ = ['AtlasSearchDataExtractor'] + + +class AtlasSearchDataExtractorHelpers: + @staticmethod + def _filter_none(input_list: List) -> List: + return list(filter(None, input_list)) + + @staticmethod + def get_entity_names(entity_list: Optional[List]) -> List: + entity_list = entity_list or [] + return AtlasSearchDataExtractorHelpers._filter_none( + [e.get('attributes').get('name') for e in entity_list if e.get('status').lower() == 'active']) + + @staticmethod + def get_entity_uri(qualified_name: str, type_name: str) -> str: + key = AtlasTableKey(qualified_name, database=type_name) + return key.amundsen_key + + @staticmethod + def get_entity_descriptions(entity_list: Optional[List]) -> List: + entity_list = entity_list or [] + return AtlasSearchDataExtractorHelpers._filter_none( + [e.get('attributes', dict()).get('description') for e in entity_list + if e.get('status').lower() == 'active']) + + @staticmethod + def get_badges_from_classifications(classifications: Optional[List]) -> List: + classifications = classifications or [] + return AtlasSearchDataExtractorHelpers._filter_none( + [c.get('typeName') for c in classifications if c.get('entityStatus', '').lower() == 'active']) + + @staticmethod + def get_display_text(meanings: Optional[List]) -> List: + meanings = meanings or [] + return AtlasSearchDataExtractorHelpers._filter_none( + [c.get('displayText') for c in meanings if c.get('entityStatus', '').lower() == 'active']) + + @staticmethod + def get_last_successful_execution_timestamp(executions: Optional[List]) -> int: + executions = executions or [] + successful_executions = AtlasSearchDataExtractorHelpers._filter_none( + [e.get('attributes').get('timestamp') for e in executions + if e.get('status', '').lower() == 'active' and e.get('attributes', dict()).get('state') == 'succeeded']) + + try: + return max(successful_executions) + except ValueError: + return 0 + + @staticmethod + def get_chart_names(queries: Optional[List]) -> List[str]: + queries = queries or [] + charts = [] + + for query in queries: + _charts = query.get('relationshipAttributes', dict()).get('charts', []) + charts += _charts + + return AtlasSearchDataExtractorHelpers.get_display_text(charts) + + @staticmethod + def get_table_database(table_key: str) -> str: + result = AtlasTableKey(table_key).get_details().get('database', 'hive_table') + + return result + + @staticmethod + def get_source_description(parameters: Optional[dict]) -> str: + parameters = parameters or dict() + + return parameters.get('sourceDescription', '') + + @staticmethod + def get_usage(readers: Optional[List]) -> Tuple[int, int]: + readers = readers or [] + + score = 0 
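+        # score: cumulative read count from active readers; unique: incremented
+        # once per reader while the cumulative score is positive (see loop below).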
+ unique = 0 + + for reader in readers: + reader_status = reader.get('status') + entity_status = reader.get('relationshipAttributes', dict()).get('entity', dict()).get('entityStatus', '') + relationship_status = reader.get('relationshipAttributes', + dict()).get('entity', + dict()).get('relationshipStatus', '') + + if reader_status == entity_status == relationship_status == 'ACTIVE': + score += reader.get('attributes', dict()).get('count', 0) + + if score > 0: + unique += 1 + + return score, unique + + +class AtlasSearchDataExtractor(Extractor): + ATLAS_URL_CONFIG_KEY = 'atlas_url' + ATLAS_PORT_CONFIG_KEY = 'atlas_port' + ATLAS_PROTOCOL_CONFIG_KEY = 'atlas_protocol' + ATLAS_VALIDATE_SSL_CONFIG_KEY = 'atlas_validate_ssl' + ATLAS_USERNAME_CONFIG_KEY = 'atlas_auth_user' + ATLAS_PASSWORD_CONFIG_KEY = 'atlas_auth_pw' + ATLAS_SEARCH_CHUNK_SIZE_KEY = 'atlas_search_chunk_size' + ATLAS_DETAILS_CHUNK_SIZE_KEY = 'atlas_details_chunk_size' + ATLAS_TIMEOUT_SECONDS_KEY = 'atlas_timeout_seconds' + ATLAS_MAX_RETRIES_KEY = 'atlas_max_retries' + + PROCESS_POOL_SIZE_KEY = 'process_pool_size' + + ENTITY_TYPE_KEY = 'entity_type' + + DEFAULT_CONFIG = ConfigFactory.from_dict({ATLAS_URL_CONFIG_KEY: "localhost", + ATLAS_PORT_CONFIG_KEY: 21000, + ATLAS_PROTOCOL_CONFIG_KEY: 'http', + ATLAS_VALIDATE_SSL_CONFIG_KEY: False, + ATLAS_SEARCH_CHUNK_SIZE_KEY: 250, + ATLAS_DETAILS_CHUNK_SIZE_KEY: 25, + ATLAS_TIMEOUT_SECONDS_KEY: 120, + ATLAS_MAX_RETRIES_KEY: 2, + PROCESS_POOL_SIZE_KEY: 10}) + + # es_document field, atlas field path, modification function, default_value + FIELDS_MAPPING_SPEC: type_fields_mapping_spec = { + 'Table': [ + ('database', 'attributes.qualifiedName', + lambda x: AtlasSearchDataExtractorHelpers.get_table_database(x), None), + ('cluster', 'attributes.qualifiedName', + lambda x: AtlasTableKey(x).get_details()['cluster'], None), + ('schema', 'attributes.qualifiedName', + lambda x: AtlasTableKey(x).get_details()['schema'], None), + ('name', 'attributes.name', None, None), + ('key', ['attributes.qualifiedName', 'typeName'], + lambda x, y: AtlasSearchDataExtractorHelpers.get_entity_uri(x, y), None), + ('description', 'attributes.description', None, None), + ('last_updated_timestamp', 'updateTime', lambda x: int(x) / 1000, 0), + ('total_usage', 'relationshipAttributes.readers', + lambda x: AtlasSearchDataExtractorHelpers.get_usage(x)[0], 0), + ('unique_usage', 'relationshipAttributes.readers', + lambda x: AtlasSearchDataExtractorHelpers.get_usage(x)[1], 0), + ('column_names', 'relationshipAttributes.columns', + lambda x: AtlasSearchDataExtractorHelpers.get_entity_names(x), []), + ('column_descriptions', 'relationshipAttributes.columns', + lambda x: AtlasSearchDataExtractorHelpers.get_entity_descriptions(x), []), + ('tags', 'relationshipAttributes.meanings', + lambda x: AtlasSearchDataExtractorHelpers.get_display_text(x), []), + ('badges', 'classifications', + lambda x: AtlasSearchDataExtractorHelpers.get_badges_from_classifications(x), []), + ('display_name', 'attributes.qualifiedName', + lambda x: '.'.join([AtlasTableKey(x).get_details()['schema'], AtlasTableKey(x).get_details()['table']]), + None), + ('schema_description', 'attributes.parameters', + lambda x: AtlasSearchDataExtractorHelpers.get_source_description(x), ''), + ('programmatic_descriptions', 'attributes.parameters', lambda x: [str(s) for s in list(x.values())], {}) + ], + 'Dashboard': [ + ('group_name', 'relationshipAttributes.group.attributes.name', None, None), + ('name', 'attributes.name', None, None), + ('description', 
'attributes.description', None, None), + ('total_usage', 'relationshipAttributes.readers', + lambda x: AtlasSearchDataExtractorHelpers.get_usage(x)[0], 0), + ('product', 'attributes.product', None, None), + ('cluster', 'attributes.cluster', None, None), + ('group_description', 'relationshipAttributes.group.attributes.description', None, None), + ('query_names', 'relationshipAttributes.queries', + lambda x: AtlasSearchDataExtractorHelpers.get_entity_names(x), []), + ('chart_names', 'relationshipAttributes.queries', + lambda x: AtlasSearchDataExtractorHelpers.get_chart_names(x), []), + ('group_url', 'relationshipAttributes.group.attributes.url', None, None), + ('url', 'attributes.url', None, None), + ('uri', 'attributes.qualifiedName', None, None), + ('last_successful_run_timestamp', 'relationshipAttributes.executions', + lambda x: AtlasSearchDataExtractorHelpers.get_last_successful_execution_timestamp(x), None), + ('tags', 'relationshipAttributes.meanings', + lambda x: AtlasSearchDataExtractorHelpers.get_display_text(x), []), + ('badges', 'classifications', + lambda x: AtlasSearchDataExtractorHelpers.get_badges_from_classifications(x), []) + ], + 'User': [ + ('email', 'attributes.qualifiedName', None, ''), + ('first_name', 'attributes.first_name', None, ''), + ('last_name', 'attributes.last_name', None, ''), + ('full_name', 'attributes.full_name', None, ''), + ('github_username', 'attributes.github_username', None, ''), + ('team_name', 'attributes.team_name', None, ''), + ('employee_type', 'attributes.employee_type', None, ''), + ('manager_email', 'attributes.manager_email', None, ''), + ('slack_id', 'attributes.slack_id', None, ''), + ('role_name', 'attributes.role_name', None, ''), + ('is_active', 'attributes.is_active', None, ''), + ('total_read', 'attributes.total_read', None, ''), + ('total_own', 'attributes.total_own', None, ''), + ('total_follow', 'attributes.total_follow', None, '') + ] + } + + ENTITY_MODEL_BY_TYPE = { + 'Table': 'databuilder.models.table_elasticsearch_document.TableESDocument', + 'Dashboard': 'databuilder.models.dashboard_elasticsearch_document.DashboardESDocument', + 'User': 'databuilder.models.user_elasticsearch_document.UserESDocument' + } + + REQUIRED_RELATIONSHIPS_BY_TYPE = { + 'Table': ['columns', 'readers'], + 'Dashboard': ['group', 'charts', 'executions', 'queries'], + 'User': [] + } + + def init(self, conf: ConfigTree) -> None: + self.conf = conf.with_fallback(AtlasSearchDataExtractor.DEFAULT_CONFIG) + self.driver = self._get_driver() + + self._extract_iter: Optional[Iterator[Any]] = None + + @property + def entity_type(self) -> str: + return self.conf.get(AtlasSearchDataExtractor.ENTITY_TYPE_KEY) + + @property + def dsl_search_query(self) -> Dict: + query = { + 'query': f'{self.entity_type} where __state = "ACTIVE"' + } + + LOGGER.debug(f'DSL Search Query: {query}') + + return query + + @property + def model_class(self) -> Any: + model_class = AtlasSearchDataExtractor.ENTITY_MODEL_BY_TYPE.get(self.entity_type) + + if model_class: + module_name, class_name = model_class.rsplit(".", 1) + mod = importlib.import_module(module_name) + + return getattr(mod, class_name) + + @property + def field_mappings(self) -> type_fields_mapping: + return AtlasSearchDataExtractor.FIELDS_MAPPING_SPEC.get(self.entity_type) or [] + + @property + def search_chunk_size(self) -> int: + return self.conf.get_int(AtlasSearchDataExtractor.ATLAS_SEARCH_CHUNK_SIZE_KEY) + + @property + def relationships(self) -> Optional[List[str]]: + return 
AtlasSearchDataExtractor.REQUIRED_RELATIONSHIPS_BY_TYPE.get(self.entity_type) # type: ignore + + def extract(self) -> Any: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.atlas_search_data' + + def _get_driver(self) -> Any: + return Atlas(host=self.conf.get_string(AtlasSearchDataExtractor.ATLAS_URL_CONFIG_KEY), + port=self.conf.get_string(AtlasSearchDataExtractor.ATLAS_PORT_CONFIG_KEY), + username=self.conf.get_string(AtlasSearchDataExtractor.ATLAS_USERNAME_CONFIG_KEY), + password=self.conf.get_string(AtlasSearchDataExtractor.ATLAS_PASSWORD_CONFIG_KEY), + protocol=self.conf.get_string(AtlasSearchDataExtractor.ATLAS_PROTOCOL_CONFIG_KEY), + validate_ssl=self.conf.get_bool(AtlasSearchDataExtractor.ATLAS_VALIDATE_SSL_CONFIG_KEY), + timeout=self.conf.get_int(AtlasSearchDataExtractor.ATLAS_TIMEOUT_SECONDS_KEY), + max_retries=self.conf.get_int(AtlasSearchDataExtractor.ATLAS_MAX_RETRIES_KEY)) + + def _get_latest_entity_metrics(self) -> Optional[dict]: + admin_metrics = list(self.driver.admin_metrics) + + try: + return admin_metrics[-1].entity + except Exception: + return None + + def _get_count_of_active_entities(self) -> int: + entity_metrics = self._get_latest_entity_metrics() + + if entity_metrics: + count = entity_metrics.get('entityActive-typeAndSubTypes', dict()).get(self.entity_type, 0) + + return int(count) + else: + return 0 + + def _get_entity_guids(self, start_offset: int) -> List[str]: + result = [] + + batch_start = start_offset + batch_end = start_offset + self.search_chunk_size + + LOGGER.info(f'Collecting guids for batch: {batch_start}-{batch_end}') + + _params = {'offset': str(batch_start), 'limit': str(self.search_chunk_size)} + + full_params = deepcopy(self.dsl_search_query) + full_params.update(**_params) + + try: + results = self.driver.search_dsl(**full_params) + + for hit in results: + for entity in hit.entities: + result.append(entity.guid) + + return result + except Exception: + LOGGER.warning(f'Error processing batch: {batch_start}-{batch_end}', exc_info=True) + + return [] + + def _get_entity_details(self, guid_list: List[str]) -> List: + result = [] + + LOGGER.info(f'Processing guids chunk of size: {len(guid_list)}') + + try: + bulk_collection = self.driver.entity_bulk(guid=guid_list) + + for collection in bulk_collection: + search_chunk = list(collection.entities_with_relationships(attributes=self.relationships)) + + result += search_chunk + + return result + except Exception: + LOGGER.warning(f'Error processing guids. 
{len(guid_list)}', exc_info=True) + + return [] + + @staticmethod + def split_list_to_chunks(input_list: List[Any], n: int) -> Generator: + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(input_list), n): + yield input_list[i:i + n] + + def _execute_query(self) -> Any: + details_chunk_size = self.conf.get_int(AtlasSearchDataExtractor.ATLAS_DETAILS_CHUNK_SIZE_KEY) + process_pool_size = self.conf.get_int(AtlasSearchDataExtractor.PROCESS_POOL_SIZE_KEY) + + guids = [] + + entity_count = self._get_count_of_active_entities() + + LOGGER.info(f'Received count: {entity_count}') + + if entity_count > 0: + offsets = [i * self.search_chunk_size for i in range(int(entity_count / self.search_chunk_size) + 1)] + else: + offsets = [] + + with multiprocessing.pool.ThreadPool(processes=process_pool_size) as pool: + guid_list = pool.map(self._get_entity_guids, offsets, chunksize=1) + + for sub_list in guid_list: + guids += sub_list + + LOGGER.info(f'Received guids: {len(guids)}') + + if guids: + guids_chunks = AtlasSearchDataExtractor.split_list_to_chunks(guids, details_chunk_size) + + with multiprocessing.pool.ThreadPool(processes=process_pool_size) as pool: + return_list = pool.map(self._get_entity_details, guids_chunks) + + for sub_list in return_list: + for entry in sub_list: + yield entry + + def _get_extract_iter(self) -> Iterator[Any]: + for atlas_entity in self._execute_query(): + model_dict = dict() + + try: + data = atlas_entity.__dict__['_data'] + + for spec in self.field_mappings: + model_field, atlas_fields_paths, _transform_spec, default_value = spec + + if not isinstance(atlas_fields_paths, list): + atlas_fields_paths = [atlas_fields_paths] + + atlas_values = [] + for atlas_field_path in atlas_fields_paths: + + atlas_value = reduce(lambda x, y: x.get(y, dict()), atlas_field_path.split('.'), + data) or default_value + atlas_values.append(atlas_value) + + transform_spec = _transform_spec or (lambda x: x) + + es_entity_value = transform_spec(*atlas_values) + model_dict[model_field] = es_entity_value + + yield self.model_class(**model_dict) + except Exception: + LOGGER.warning('Error building model object.', exc_info=True) diff --git a/databuilder/databuilder/extractor/base_bigquery_extractor.py b/databuilder/databuilder/extractor/base_bigquery_extractor.py new file mode 100644 index 0000000000..debe359efd --- /dev/null +++ b/databuilder/databuilder/extractor/base_bigquery_extractor.py @@ -0,0 +1,169 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import json +import logging +import re +from collections import namedtuple +from datetime import datetime, timezone +from typing import ( + Any, Dict, Iterator, List, +) + +import google.oauth2.service_account +import google_auth_httplib2 +import httplib2 +from googleapiclient.discovery import build +from pyhocon import ConfigTree + +from databuilder.extractor.base_extractor import Extractor + +DatasetRef = namedtuple('DatasetRef', ['datasetId', 'projectId']) +TableKey = namedtuple('TableKey', ['schema', 'table_name']) + +LOGGER = logging.getLogger(__name__) + + +class BaseBigQueryExtractor(Extractor): + PROJECT_ID_KEY = 'project_id' + KEY_PATH_KEY = 'key_path' + # sometimes we don't have a key path, but only have an variable + CRED_KEY = 'project_cred' + PAGE_SIZE_KEY = 'page_size' + FILTER_KEY = 'filter' + # metadata for tables created after the cutoff time would not be extracted from bigquery. 
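+    # The value is expected in DATE_TIME_FORMAT (defined below), e.g. '2021-06-01T00:00:00Z' (example value only).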
+ CUTOFF_TIME_KEY = 'cutoff_time' + _DEFAULT_SCOPES = ['https://www.googleapis.com/auth/bigquery.readonly'] + DEFAULT_PAGE_SIZE = 300 + NUM_RETRIES = 3 + DATE_LENGTH = 8 + DATE_TIME_FORMAT = '%Y-%m-%dT%H:%M:%SZ' + + def init(self, conf: ConfigTree) -> None: + # should use key_path, or cred_key if the former doesn't exist + self.key_path = conf.get_string(BaseBigQueryExtractor.KEY_PATH_KEY, None) + self.cred_key = conf.get_string(BaseBigQueryExtractor.CRED_KEY, None) + self.project_id = conf.get_string(BaseBigQueryExtractor.PROJECT_ID_KEY) + self.pagesize = conf.get_int( + BaseBigQueryExtractor.PAGE_SIZE_KEY, + BaseBigQueryExtractor.DEFAULT_PAGE_SIZE) + self.filter = conf.get_string(BaseBigQueryExtractor.FILTER_KEY, '') + self.cutoff_time = conf.get_string(BaseBigQueryExtractor.CUTOFF_TIME_KEY, + datetime.now(timezone.utc).strftime(BaseBigQueryExtractor.DATE_TIME_FORMAT)) + + if self.key_path: + credentials = ( + google.oauth2.service_account.Credentials.from_service_account_file( + self.key_path, scopes=self._DEFAULT_SCOPES)) + else: + if self.cred_key: + service_account_info = json.loads(self.cred_key) + credentials = ( + google.oauth2.service_account.Credentials.from_service_account_info( + service_account_info, scopes=self._DEFAULT_SCOPES)) + else: + # FIXME: mypy can't find this attribute + google_auth: Any = getattr(google, 'auth') + credentials, _ = google_auth.default(scopes=self._DEFAULT_SCOPES) + + http = httplib2.Http() + authed_http = google_auth_httplib2.AuthorizedHttp(credentials, http=http) + self.bigquery_service = build('bigquery', 'v2', http=authed_http, cache_discovery=False) + self.logging_service = build('logging', 'v2', http=authed_http, cache_discovery=False) + self.iter: Iterator[Any] = iter([]) + + def extract(self) -> Any: + try: + return next(self.iter) + except StopIteration: + return None + + def _is_sharded_table(self, table_id: str) -> bool: + """ + Table with a numeric suffix starting with a date string + will be considered as a sharded table + :param table_id: + :return: + """ + suffix = self._get_sharded_table_suffix(table_id) + if len(suffix) < BaseBigQueryExtractor.DATE_LENGTH: + return False + + suffix_date = suffix[:BaseBigQueryExtractor.DATE_LENGTH] + try: + datetime.strptime(suffix_date, '%Y%m%d') + return True + except ValueError: + return False + + def _get_sharded_table_suffix(self, table_id: str) -> str: + suffix_match = re.search(r'\d+$', table_id) + suffix = suffix_match.group() if suffix_match else '' + return suffix + + def _iterate_over_tables(self) -> Any: + for dataset in self._retrieve_datasets(): + for entry in self._retrieve_tables(dataset): + yield entry + + # TRICKY: this function has different return types between different subclasses, + # so type as Any. Should probably refactor to remove this unclear sharing. 
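+    # (For example, BigQueryMetadataExtractor yields TableMetadata while BigQueryWatermarkExtractor yields Watermark.)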
+ def _retrieve_tables(self, dataset: DatasetRef) -> Any: + pass + + def _retrieve_datasets(self) -> List[DatasetRef]: + datasets = [] + for page in self._page_dataset_list_results(): + if 'datasets' not in page: + continue + + for dataset in page['datasets']: + dataset_ref = dataset['datasetReference'] + ref = DatasetRef(**dataset_ref) + datasets.append(ref) + + return datasets + + def _page_dataset_list_results(self) -> Iterator[Any]: + response = self.bigquery_service.datasets().list( + projectId=self.project_id, + all=False, # Do not return hidden datasets + filter=self.filter, + maxResults=self.pagesize).execute( + num_retries=BaseBigQueryExtractor.NUM_RETRIES) + + while response: + yield response + + if 'nextPageToken' in response: + response = self.bigquery_service.datasets().list( + projectId=self.project_id, + all=True, + filter=self.filter, + pageToken=response['nextPageToken']).execute( + num_retries=BaseBigQueryExtractor.NUM_RETRIES) + else: + response = None + + def _page_table_list_results(self, dataset: DatasetRef) -> Iterator[Dict[str, Any]]: + response = self.bigquery_service.tables().list( + projectId=dataset.projectId, + datasetId=dataset.datasetId, + maxResults=self.pagesize).execute( + num_retries=BaseBigQueryExtractor.NUM_RETRIES) + + while response: + yield response + + if 'nextPageToken' in response: + response = self.bigquery_service.tables().list( + projectId=dataset.projectId, + datasetId=dataset.datasetId, + maxResults=self.pagesize, + pageToken=response['nextPageToken']).execute( + num_retries=BaseBigQueryExtractor.NUM_RETRIES) + else: + response = None + + def get_scope(self) -> str: + return 'extractor.bigquery_table_metadata' diff --git a/databuilder/databuilder/extractor/base_extractor.py b/databuilder/databuilder/extractor/base_extractor.py new file mode 100644 index 0000000000..7d9d52da9d --- /dev/null +++ b/databuilder/databuilder/extractor/base_extractor.py @@ -0,0 +1,29 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import abc +from typing import Any + +from pyhocon import ConfigTree + +from databuilder import Scoped + + +class Extractor(Scoped): + """ + An extractor extracts record + """ + + @abc.abstractmethod + def init(self, conf: ConfigTree) -> None: + pass + + @abc.abstractmethod + def extract(self) -> Any: + """ + :return: Provides a record or None if no more to extract + """ + return None + + def get_scope(self) -> str: + return 'extractor' diff --git a/databuilder/databuilder/extractor/base_postgres_metadata_extractor.py b/databuilder/databuilder/extractor/base_postgres_metadata_extractor.py new file mode 100644 index 0000000000..6666cbd474 --- /dev/null +++ b/databuilder/databuilder/extractor/base_postgres_metadata_extractor.py @@ -0,0 +1,117 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import abc +import logging +from collections import namedtuple +from itertools import groupby +from typing import ( + Any, Dict, Iterator, Union, +) + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +TableKey = namedtuple('TableKey', ['schema', 'table_name']) + +LOGGER = logging.getLogger(__name__) + + +class BasePostgresMetadataExtractor(Extractor): + """ + Extracts Postgres table and column metadata from underlying meta store database using SQLAlchemyExtractor + """ + + # CONFIG KEYS + WHERE_CLAUSE_SUFFIX_KEY = 'where_clause_suffix' + CLUSTER_KEY = 'cluster_key' + USE_CATALOG_AS_CLUSTER_NAME = 'use_catalog_as_cluster_name' + DATABASE_KEY = 'database_key' + + # Default values + DEFAULT_CLUSTER_NAME = 'master' + + DEFAULT_CONFIG = ConfigFactory.from_dict( + {WHERE_CLAUSE_SUFFIX_KEY: 'true', CLUSTER_KEY: DEFAULT_CLUSTER_NAME, USE_CATALOG_AS_CLUSTER_NAME: True} + ) + + @abc.abstractmethod + def get_sql_statement(self, use_catalog_as_cluster_name: bool, where_clause_suffix: str) -> Any: + """ + :return: Provides a record or None if no more to extract + """ + return None + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(BasePostgresMetadataExtractor.DEFAULT_CONFIG) + self._cluster = conf.get_string(BasePostgresMetadataExtractor.CLUSTER_KEY) + + self._database = conf.get_string(BasePostgresMetadataExtractor.DATABASE_KEY, default='postgres') + + self.sql_stmt = self.get_sql_statement( + use_catalog_as_cluster_name=conf.get_bool(BasePostgresMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME), + where_clause_suffix=conf.get_string(BasePostgresMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY), + ) + + self._alchemy_extractor = SQLAlchemyExtractor() + sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope())\ + .with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt})) + + self.sql_stmt = sql_alch_conf.get_string(SQLAlchemyExtractor.EXTRACT_SQL) + + LOGGER.info('SQL for postgres metadata: %s', self.sql_stmt) + + self._alchemy_extractor.init(sql_alch_conf) + self._extract_iter: Union[None, Iterator] = None + + def extract(self) -> Union[TableMetadata, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def _get_extract_iter(self) -> Iterator[TableMetadata]: + """ + Using itertools.groupby and raw level iterator, it groups to table and yields TableMetadata + :return: + """ + for key, group in groupby(self._get_raw_extract_iter(), self._get_table_key): + columns = [] + + for row in group: + last_row = row + columns.append(ColumnMetadata(row['col_name'], row['col_description'], + row['col_type'], row['col_sort_order'])) + + yield TableMetadata(self._database, last_row['cluster'], + last_row['schema'], + last_row['name'], + last_row['description'], + columns) + + def _get_raw_extract_iter(self) -> Iterator[Dict[str, Any]]: + """ + Provides iterator of result row from SQLAlchemy extractor + :return: + """ + row = self._alchemy_extractor.extract() + while row: + yield row + row = self._alchemy_extractor.extract() + + def _get_table_key(self, row: Dict[str, Any]) -> Union[TableKey, None]: + """ + Table key consists of schema and table name + 
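+        (e.g. TableKey(schema='public', table_name='users') -- illustrative values only)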
:param row: + :return: + """ + if row: + return TableKey(schema=row['schema'], table_name=row['name']) + + return None diff --git a/databuilder/databuilder/extractor/base_teradata_metadata_extractor.py b/databuilder/databuilder/extractor/base_teradata_metadata_extractor.py new file mode 100644 index 0000000000..56fcba063f --- /dev/null +++ b/databuilder/databuilder/extractor/base_teradata_metadata_extractor.py @@ -0,0 +1,141 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import abc +import logging +from collections import namedtuple +from itertools import groupby +from typing import ( + Any, Dict, Iterator, Union, +) + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +TableKey = namedtuple("TableKey", ["schema", "table_name"]) + +LOGGER = logging.getLogger(__name__) + + +class BaseTeradataMetadataExtractor(Extractor): + """ + Extracts Teradata table and column metadata from underlying meta store database using SQLAlchemyExtractor + """ + + # CONFIG KEYS + WHERE_CLAUSE_SUFFIX_KEY = "where_clause_suffix" + CLUSTER_KEY = "cluster_key" + USE_CATALOG_AS_CLUSTER_NAME = "use_catalog_as_cluster_name" + DATABASE_KEY = "database_key" + + # Default values + DEFAULT_CLUSTER_NAME = "master" + + DEFAULT_CONFIG = ConfigFactory.from_dict( + { + WHERE_CLAUSE_SUFFIX_KEY: "true", + CLUSTER_KEY: DEFAULT_CLUSTER_NAME, + USE_CATALOG_AS_CLUSTER_NAME: True, + } + ) + + @abc.abstractmethod + def get_sql_statement( + self, use_catalog_as_cluster_name: bool, where_clause_suffix: str + ) -> Any: + """ + :return: Provides a record or None if no more to extract + """ + return None + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(BaseTeradataMetadataExtractor.DEFAULT_CONFIG) + self._cluster = conf.get_string(BaseTeradataMetadataExtractor.CLUSTER_KEY) + + self._database = conf.get_string( + BaseTeradataMetadataExtractor.DATABASE_KEY, default="teradata" + ) + + self.sql_stmt = self.get_sql_statement( + use_catalog_as_cluster_name=conf.get_bool( + BaseTeradataMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME + ), + where_clause_suffix=conf.get_string( + BaseTeradataMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY + ), + ) + + self._alchemy_extractor = SQLAlchemyExtractor() + sql_alch_conf = Scoped.get_scoped_conf( + conf, self._alchemy_extractor.get_scope() + ).with_fallback( + ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt}) + ) + + self.sql_stmt = sql_alch_conf.get_string(SQLAlchemyExtractor.EXTRACT_SQL) + + LOGGER.info("SQL for teradata metadata: %s", self.sql_stmt) + + self._alchemy_extractor.init(sql_alch_conf) + self._extract_iter: Union[None, Iterator] = None + + def extract(self) -> Union[TableMetadata, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def _get_extract_iter(self) -> Iterator[TableMetadata]: + """ + Using itertools.groupby and raw level iterator, it groups to table and yields TableMetadata + :return: + """ + for key, group in groupby(self._get_raw_extract_iter(), self._get_table_key): + columns = [] + + for row in group: + last_row = row + columns.append( + ColumnMetadata( + row["col_name"], + row["col_description"], + row["col_type"], + 
row["col_sort_order"], + ) + ) + + yield TableMetadata( + self._database, + last_row["td_cluster"], + last_row["schema"], + last_row["name"], + last_row["description"], + columns, + ) + + def _get_raw_extract_iter(self) -> Iterator[Dict[str, Any]]: + """ + Provides iterator of result row from SQLAlchemy extractor + :return: + """ + row = self._alchemy_extractor.extract() + while row: + yield row + row = self._alchemy_extractor.extract() + + def _get_table_key(self, row: Dict[str, Any]) -> Union[TableKey, None]: + """ + Table key consists of schema and table name + :param row: + :return: + """ + if row: + return TableKey(schema=row["schema"], table_name=row["name"]) + + return None diff --git a/databuilder/databuilder/extractor/bigquery_metadata_extractor.py b/databuilder/databuilder/extractor/bigquery_metadata_extractor.py new file mode 100644 index 0000000000..0c459d07f7 --- /dev/null +++ b/databuilder/databuilder/extractor/bigquery_metadata_extractor.py @@ -0,0 +1,131 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import ( + Any, Callable, Dict, List, Set, cast, +) + +from googleapiclient.errors import HttpError +from pyhocon import ConfigTree + +from databuilder.extractor.base_bigquery_extractor import BaseBigQueryExtractor, DatasetRef +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +LOGGER = logging.getLogger(__name__) + + +class BigQueryMetadataExtractor(BaseBigQueryExtractor): + + """ A metadata extractor for bigquery tables, taking the schema metadata + from the google cloud bigquery API's. This extractor goes through all visible + datasets in the project identified by project_id and iterates over all tables + it finds. A separate account is configurable through the key_path parameter, + which should point to a valid json file corresponding to a service account. + + This extractor supports nested columns, which are delimited by a dot (.) in the + column name. + """ + + def init(self, conf: ConfigTree) -> None: + BaseBigQueryExtractor.init(self, conf) + self.iter = iter(self._iterate_over_tables()) + + def _retrieve_tables(self, dataset: DatasetRef) -> Any: # noqa: max-complexity: 12 + grouped_tables: Set[str] = set([]) + + for page in self._page_table_list_results(dataset): + if 'tables' not in page: + continue + + for table in page['tables']: + tableRef = table['tableReference'] + table_id = tableRef['tableId'] + + # BigQuery tables that have numeric suffix starting with a date string will be + # considered date range tables. + # ( e.g. ga_sessions_20190101, ga_sessions_20190102, etc. 
) + if self._is_sharded_table(table_id): + # Sharded tables have numeric suffix starting with a date string + # and then we only need one schema definition + table_prefix = table_id[:-len(self._get_sharded_table_suffix(table_id))] + if table_prefix in grouped_tables: + # If one table in the date range is processed, then ignore other ones + # (it adds too much metadata) + continue + + table_id = table_prefix + grouped_tables.add(table_prefix) + + try: + table = self.bigquery_service.tables().get( + projectId=tableRef['projectId'], + datasetId=tableRef['datasetId'], + tableId=tableRef['tableId']).execute(num_retries=BigQueryMetadataExtractor.NUM_RETRIES) + except HttpError as err: + # While iterating over the tables in a dataset, some temporary tables might be deleted + # this causes 404 errors, so we should handle them gracefully + LOGGER.error(err) + continue + + # BigQuery tables also have interesting metadata about partitioning + # data location (EU/US), mod/create time, etc... Extract that some other time? + cols: List[ColumnMetadata] = [] + # Not all tables have schemas + if 'schema' in table: + schema = table['schema'] + if 'fields' in schema: + total_cols = 0 + for column in schema['fields']: + # TRICKY: this mutates :cols: + total_cols = self._iterate_over_cols('', column, cols, total_cols + 1) + + table_meta = TableMetadata( + database='bigquery', + cluster=tableRef['projectId'], + schema=tableRef['datasetId'], + name=table_id, + description=table.get('description', None), + columns=cols, + is_view=table['type'] == 'VIEW') + + yield table_meta + + def _iterate_over_cols(self, + parent: str, + column: Dict[str, str], + cols: List[ColumnMetadata], + total_cols: int) -> int: + get_column_type: Callable[[dict], str] = lambda column: column['type'] + ':' + column['mode']\ + if column.get('mode') else column['type'] + if len(parent) > 0: + col_name = f'{parent}.{column["name"]}' + else: + col_name = column['name'] + + if column['type'] == 'RECORD': + col = ColumnMetadata( + name=col_name, + description=column.get('description', None), + col_type=get_column_type(column), + sort_order=total_cols) + cols.append(col) + total_cols += 1 + for field in column['fields']: + # TODO field is actually a TableFieldSchema, per + # https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#TableFieldSchema + # however it's typed as str, which is incorrect. Work-around by casting. + field_casted = cast(Dict[str, str], field) + total_cols = self._iterate_over_cols(col_name, field_casted, cols, total_cols) + return total_cols + else: + col = ColumnMetadata( + name=col_name, + description=column.get('description', None), + col_type=get_column_type(column), + sort_order=total_cols) + cols.append(col) + return total_cols + 1 + + def get_scope(self) -> str: + return 'extractor.bigquery_table_metadata' diff --git a/databuilder/databuilder/extractor/bigquery_usage_extractor.py b/databuilder/databuilder/extractor/bigquery_usage_extractor.py new file mode 100644 index 0000000000..bb71256c01 --- /dev/null +++ b/databuilder/databuilder/extractor/bigquery_usage_extractor.py @@ -0,0 +1,202 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import re +from collections import namedtuple +from datetime import ( + datetime, timedelta, timezone, +) +from time import sleep +from typing import ( + Any, Dict, Iterator, List, Optional, Tuple, +) + +from pyhocon import ConfigTree + +from databuilder.extractor.base_bigquery_extractor import BaseBigQueryExtractor + +TableColumnUsageTuple = namedtuple('TableColumnUsageTuple', ['database', 'cluster', 'schema', + 'table', 'column', 'email']) + +LOGGER = logging.getLogger(__name__) + + +class BigQueryTableUsageExtractor(BaseBigQueryExtractor): + """ + An aggregate extractor for bigquery table usage. This class takes the data from + the stackdriver logging API by filtering on timestamp, bigquery_resource and looking + for referencedTables in the response. + """ + TIMESTAMP_KEY = 'timestamp' + _DEFAULT_SCOPES = ['https://www.googleapis.com/auth/cloud-platform'] + EMAIL_PATTERN = 'email_pattern' + DELAY_TIME = 'delay_time' + TABLE_DECORATORS = ['$', '@'] + COUNT_READS_ONLY_FROM_PROJECT_ID_KEY = 'count_reads_only_from_project_id_key' + + def init(self, conf: ConfigTree) -> None: + BaseBigQueryExtractor.init(self, conf) + self.timestamp = conf.get_string( + BigQueryTableUsageExtractor.TIMESTAMP_KEY, + (datetime.now(timezone.utc) - timedelta(days=1)).strftime(BigQueryTableUsageExtractor.DATE_TIME_FORMAT)) + + self.email_pattern = conf.get_string(BigQueryTableUsageExtractor.EMAIL_PATTERN, None) + self.delay_time = conf.get_int(BigQueryTableUsageExtractor.DELAY_TIME, 100) + self.table_usage_counts: Dict[TableColumnUsageTuple, int] = {} + # GCP console allows running queries using tables from a project different from the one the extractor is + # used for; only usage metadata of referenced tables present in the given project_id_key for the + # extractor is taken into account and usage metadata of referenced tables from other projects + # is ignored by "default". + self.count_reads_only_from_same_project = conf.get_bool( + BigQueryTableUsageExtractor.COUNT_READS_ONLY_FROM_PROJECT_ID_KEY, True) + self._count_usage() + self.iter = iter(self.table_usage_counts) + + def _count_usage(self) -> None: # noqa: C901 + count = 0 + for entry in self._retrieve_records(): + count += 1 + if count % self.pagesize == 0: + LOGGER.info(f'Aggregated {count} records') + + if entry is None: + continue + + try: + job = entry['protoPayload']['serviceData']['jobCompletedEvent']['job'] + except Exception: + continue + if job['jobStatus']['state'] != 'DONE': + # This job seems not to have finished yet, so we ignore it. + continue + if len(job['jobStatus'].get('error', {})) > 0: + # This job has errors, so we ignore it + continue + + email = entry['protoPayload']['authenticationInfo']['principalEmail'] + # Query results can be cached and if the source tables remain untouched, + # bigquery will return it from a 24 hour cache result instead. 
In that + # case, referencedTables has been observed to be empty: + # https://cloud.google.com/logging/docs/reference/audit/bigquery/rest/Shared.Types/AuditData#JobStatistics + + refTables = job['jobStatistics'].get('referencedTables', None) + if refTables: + if 'totalTablesProcessed' in job['jobStatistics']: + self._create_records( + refTables, + job['jobStatistics']['totalTablesProcessed'], email, + job['jobName']['jobId']) + + refViews = job['jobStatistics'].get('referencedViews', None) + if refViews: + if 'totalViewsProcessed' in job['jobStatistics']: + self._create_records( + refViews, job['jobStatistics']['totalViewsProcessed'], + email, job['jobName']['jobId']) + + def _create_records(self, refResources: List[dict], resourcesProcessed: int, email: str, + jobId: str) -> None: + # if email filter is provided, only the email matched with filter will be recorded. + if self.email_pattern: + if not re.match(self.email_pattern, email): + # the usage account not match email pattern + return + + if len(refResources) != resourcesProcessed: + LOGGER.warning(f'The number of tables listed in job {jobId} is not consistent') + return + + for refResource in refResources: + tableId = refResource.get('tableId') + datasetId = refResource.get('datasetId') + + if not datasetId or not tableId: + # handling case when the referenced table is an external table + # Which doesn't have a datasetId + continue + + if self._is_anonymous_dataset(datasetId) or self._is_wildcard_table(tableId): + continue + + tableId = self._remove_table_decorators(tableId) + + if self._is_sharded_table(tableId): + # Use the prefix of the sharded table as tableId + tableId = tableId[:-len(self._get_sharded_table_suffix(tableId))] + + if refResource['projectId'] != self.project_id and self.count_reads_only_from_same_project: + LOGGER.debug( + f'Not counting usage for {refResource} since {tableId} ' + f'is not present in {self.project_id} ' + f'and {BigQueryTableUsageExtractor.COUNT_READS_ONLY_FROM_PROJECT_ID_KEY} is True') + continue + else: + key = TableColumnUsageTuple(database='bigquery', + cluster=refResource['projectId'], + schema=datasetId, + table=tableId, + column='*', + email=email) + + new_count = self.table_usage_counts.get(key, 0) + 1 + self.table_usage_counts[key] = new_count + + def _retrieve_records(self) -> Iterator[Optional[Dict]]: + """ + Extracts bigquery log data by looking at the principalEmail in the authenticationInfo block and + referencedTables in the jobStatistics and filters out log entries of metadata queries. 
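+        (Entries are fetched via the Cloud Logging entries().list API using the filter assembled in `body` below.)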
+ :return: Provides a record or None if no more to extract + """ + body = { + 'resourceNames': [f'projects/{self.project_id}'], + 'pageSize': self.pagesize, + 'filter': 'protoPayload.methodName="jobservice.jobcompleted" AND ' + 'resource.type="bigquery_resource" AND ' + 'NOT protoPayload.serviceData.jobCompletedEvent.job.jobConfiguration.query.query:(' + 'INFORMATION_SCHEMA OR __TABLES__) AND ' + f'timestamp >= "{self.timestamp}" AND timestamp < "{self.cutoff_time}"' + } + for page in self._page_over_results(body): + for entry in page['entries']: + yield entry + + def extract(self) -> Optional[Tuple[Any, int]]: + try: + key = next(self.iter) + return key, self.table_usage_counts[key] + except StopIteration: + return None + + def _page_over_results(self, body: Dict) -> Iterator[Dict]: + response = self.logging_service.entries().list(body=body).execute( + num_retries=BigQueryTableUsageExtractor.NUM_RETRIES) + while response: + if 'entries' in response: + yield response + + try: + if 'nextPageToken' in response: + body['pageToken'] = response['nextPageToken'] + response = self.logging_service.entries().list(body=body).execute( + num_retries=BigQueryTableUsageExtractor.NUM_RETRIES) + else: + response = None + except Exception: + # Add a delay when BQ quota exceeds limitation + sleep(self.delay_time) + + def _remove_table_decorators(self, tableId: str) -> Optional[str]: + for decorator in BigQueryTableUsageExtractor.TABLE_DECORATORS: + tableId = tableId.split(decorator)[0] + return tableId + + def _is_anonymous_dataset(self, datasetId: str) -> bool: + # temporary/cached results tables are stored in anonymous datasets that have names starting with '_' + return datasetId.startswith('_') + + def _is_wildcard_table(self, tableId: str) -> bool: + return '*' in tableId + + def get_scope(self) -> str: + return 'extractor.bigquery_table_usage' diff --git a/databuilder/databuilder/extractor/bigquery_watermark_extractor.py b/databuilder/databuilder/extractor/bigquery_watermark_extractor.py new file mode 100644 index 0000000000..245c286f18 --- /dev/null +++ b/databuilder/databuilder/extractor/bigquery_watermark_extractor.py @@ -0,0 +1,157 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import datetime +import logging +import textwrap +import time +from calendar import timegm +from collections import namedtuple +from typing import ( + Any, Dict, Iterator, List, Tuple, Union, +) + +from pyhocon import ConfigTree + +from databuilder.extractor.base_bigquery_extractor import BaseBigQueryExtractor, DatasetRef +from databuilder.models.watermark import Watermark + +PartitionInfo = namedtuple('PartitionInfo', ['partition_id', 'epoch_created']) + +LOGGER = logging.getLogger(__name__) + + +class BigQueryWatermarkExtractor(BaseBigQueryExtractor): + + def init(self, conf: ConfigTree) -> None: + BaseBigQueryExtractor.init(self, conf) + self.iter: Iterator[Watermark] = iter(self._iterate_over_tables()) + + def get_scope(self) -> str: + return 'extractor.bigquery_watermarks' + + def _retrieve_tables(self, + dataset: DatasetRef + ) -> Iterator[Watermark]: + sharded_table_watermarks: Dict[str, Dict[str, Union[str, Any]]] = {} + cutoff_time_in_epoch = timegm(time.strptime(self.cutoff_time, BigQueryWatermarkExtractor.DATE_TIME_FORMAT)) + + for page in self._page_table_list_results(dataset): + if 'tables' not in page: + continue + + for table in page['tables']: + tableRef = table['tableReference'] + table_id = tableRef['tableId'] + table_creation_time = float(table['creationTime']) / 1000 + # only extract watermark metadata for tables created before the cut-off time + if table_creation_time < cutoff_time_in_epoch: + # BigQuery tables that have numeric suffix starts with a date are + # considered date range tables. + # ( e.g. ga_sessions_20190101, ga_sessions_20190102, etc. ) + # We use these dates in the suffixes to determine high and low watermarks + if self._is_sharded_table(table_id): + suffix = self._get_sharded_table_suffix(table_id) + prefix = table_id[:-len(suffix)] + date = suffix[:BaseBigQueryExtractor.DATE_LENGTH] + + if prefix in sharded_table_watermarks: + sharded_table_watermarks[prefix]['low'] = min( + sharded_table_watermarks[prefix]['low'], date) + sharded_table_watermarks[prefix]['high'] = max( + sharded_table_watermarks[prefix]['high'], date) + else: + sharded_table_watermarks[prefix] = {'high': date, 'low': date, 'table': table} + else: + partitions = self._get_partitions(table, tableRef) + if not partitions: + continue + low, high = self._get_partition_watermarks(table, tableRef, partitions) + yield low + yield high + + for prefix, td in sharded_table_watermarks.items(): + table = td['table'] + tableRef = table['tableReference'] + + yield Watermark( + datetime.datetime.fromtimestamp(float(table['creationTime']) / 1000).strftime('%Y-%m-%d %H:%M:%S'), + 'bigquery', + tableRef['datasetId'], + prefix, + f'__table__={td["low"]}', + part_type="low_watermark", + cluster=tableRef['projectId'] + ) + + yield Watermark( + datetime.datetime.fromtimestamp(float(table['creationTime']) / 1000).strftime('%Y-%m-%d %H:%M:%S'), + 'bigquery', + tableRef['datasetId'], + prefix, + f'__table__={td["high"]}', + part_type="high_watermark", + cluster=tableRef['projectId'] + ) + + def _get_partitions(self, + table: str, + tableRef: Dict[str, str] + ) -> List[PartitionInfo]: + if 'timePartitioning' not in table: + return [] + + query = textwrap.dedent(""" + SELECT partition_id, + TIMESTAMP(creation_time/1000) AS creation_time + FROM [{project}:{dataset}.{table}$__PARTITIONS_SUMMARY__] + WHERE partition_id <> '__UNPARTITIONED__' + AND partition_id <> '__NULL__' + """) + body = { + 'query': query.format( + project=tableRef['projectId'], + 
dataset=tableRef['datasetId'], + table=tableRef['tableId']), + 'useLegacySql': True + } + result = self.bigquery_service.jobs().query(projectId=self.project_id, body=body).execute() + + if 'rows' not in result: + return [] + + return [PartitionInfo(row['f'][0]['v'], row['f'][1]['v']) for row in result['rows']] + + def _get_partition_watermarks(self, + table: Dict[str, Any], + tableRef: Dict[str, str], + partitions: List[PartitionInfo] + ) -> Tuple[Watermark, Watermark]: + if 'field' in table['timePartitioning']: + field = table['timePartitioning']['field'] + else: + field = '_PARTITIONTIME' + + low = min(partitions, key=lambda t: t.partition_id) + low_wm = Watermark( + datetime.datetime.fromtimestamp(float(low.epoch_created)).strftime('%Y-%m-%d %H:%M:%S'), + 'bigquery', + tableRef['datasetId'], + tableRef['tableId'], + f'{field}={low.partition_id}', + part_type="low_watermark", + cluster=tableRef['projectId'] + ) + + high = max(partitions, key=lambda t: t.partition_id) + high_wm = Watermark( + datetime.datetime.fromtimestamp(float(high.epoch_created)).strftime('%Y-%m-%d %H:%M:%S'), + 'bigquery', + tableRef['datasetId'], + tableRef['tableId'], + f'{field}={high.partition_id}', + part_type="high_watermark", + cluster=tableRef['projectId'] + ) + + return low_wm, high_wm diff --git a/databuilder/databuilder/extractor/cassandra_extractor.py b/databuilder/databuilder/extractor/cassandra_extractor.py new file mode 100644 index 0000000000..81b6ddc88d --- /dev/null +++ b/databuilder/databuilder/extractor/cassandra_extractor.py @@ -0,0 +1,103 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + Dict, Iterator, Union, +) + +import cassandra.metadata +from cassandra.cluster import Cluster +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor.base_extractor import Extractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + + +class CassandraExtractor(Extractor): + """ + Extracts tables and columns metadata from Apacha Cassandra + """ + + CLUSTER_KEY = 'cluster' + # Key to define clusters ips, it should be List[str] + IPS_KEY = 'ips' + # Key to define extra kwargs to pass on cluster constructor, + # it should be Dict[Any] + KWARGS_KEY = 'kwargs' + # Key to define custom filter function based on keyspace and table + # since the cluster metadata doesn't support native filters, + # it should be like def filter(keyspace: str, table: str) -> bool and return False if + # going to skip that table and True if not + FILTER_FUNCTION_KEY = 'filter' + + # Default values + DEFAULT_CONFIG = ConfigFactory.from_dict({ + CLUSTER_KEY: 'gold', + IPS_KEY: [], + KWARGS_KEY: {}, + FILTER_FUNCTION_KEY: None + }) + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(CassandraExtractor.DEFAULT_CONFIG) + self._cluster = conf.get_string(CassandraExtractor.CLUSTER_KEY) + self._filter = conf.get(CassandraExtractor.FILTER_FUNCTION_KEY) + ips = conf.get_list(CassandraExtractor.IPS_KEY) + kwargs = conf.get(CassandraExtractor.KWARGS_KEY) + self._client = Cluster(ips, **kwargs) + self._client.connect() + self._extract_iter: Union[None, Iterator] = None + + def extract(self) -> Union[TableMetadata, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.cassandra' + + def _get_extract_iter(self) -> Iterator[TableMetadata]: + """ + It gets all 
tables and yields TableMetadata + :return: + """ + keyspaces = self._get_keyspaces() + for keyspace in keyspaces: + # system keyspaces + if keyspace.startswith('system'): + continue + for table in self._get_tables(keyspace): + if self._filter and not self._filter(keyspace, table): + continue + + columns = [] + + columns_dict = self._get_columns(keyspace, table) + for idx, (column_name, column) in enumerate(columns_dict.items()): + columns.append(ColumnMetadata( + column_name, + None, + column.cql_type, + idx + )) + + yield TableMetadata( + 'cassandra', + self._cluster, + keyspace, + table, + None, + columns + ) + + def _get_keyspaces(self) -> Dict[str, cassandra.metadata.KeyspaceMetadata]: + return self._client.metadata.keyspaces + + def _get_tables(self, keyspace: str) -> Dict[str, cassandra.metadata.TableMetadata]: + return self._client.metadata.keyspaces[keyspace].tables + + def _get_columns(self, keyspace: str, table: str) -> Dict[str, cassandra.metadata.ColumnMetadata]: + return self._client.metadata.keyspaces[keyspace].tables[table].columns diff --git a/databuilder/databuilder/extractor/csv_extractor.py b/databuilder/databuilder/extractor/csv_extractor.py new file mode 100644 index 0000000000..10b33eb825 --- /dev/null +++ b/databuilder/databuilder/extractor/csv_extractor.py @@ -0,0 +1,362 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import csv +import importlib +from collections import defaultdict +from typing import Any, List + +from pyhocon import ConfigTree + +from databuilder.extractor.base_extractor import Extractor +from databuilder.models.badge import Badge, BadgeMetadata +from databuilder.models.table_lineage import ColumnLineage, TableLineage +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + + +def split_badge_list(badges: str, separator: str) -> List[str]: + """ + Splits a string of badges into a list, removing all empty badges. + """ + if badges is None: + return [] + + return [badge for badge in badges.split(separator) if badge] + + +class CsvExtractor(Extractor): + # Config keys + FILE_LOCATION = 'file_location' + + """ + An Extractor that extracts records via CSV. + """ + + def init(self, conf: ConfigTree) -> None: + """ + :param conf: + """ + self.conf = conf + self.file_location = conf.get_string(CsvExtractor.FILE_LOCATION) + + model_class = conf.get('model_class', None) + if model_class: + module_name, class_name = model_class.rsplit(".", 1) + mod = importlib.import_module(module_name) + self.model_class = getattr(mod, class_name) + self._load_csv() + + def _load_csv(self) -> None: + """ + Create an iterator to execute sql. + """ + if not hasattr(self, 'results'): + with open(self.file_location, 'r') as fin: + self.results = [dict(i) for i in csv.DictReader(fin)] + + if hasattr(self, 'model_class'): + results = [self.model_class(**result) + for result in self.results] + else: + results = self.results + self.iter = iter(results) + + def extract(self) -> Any: + """ + Yield the csv result one at a time. + convert the result to model if a model_class is provided + """ + try: + return next(self.iter) + except StopIteration: + return None + except Exception as e: + raise e + + def get_scope(self) -> str: + return 'extractor.csv' + + +class CsvTableBadgeExtractor(Extractor): + # Config keys + TABLE_FILE_LOCATION = 'table_file_location' + BADGE_FILE_LOCATION = 'badge_file_location' + BADGE_SEPARATOR = 'badge_separator' + + """ + An Extractor that combines Table and Badge CSVs. 
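The `CsvExtractor` above only needs a file location (and, optionally, a model class) before it can be driven through the standard `init`/`extract` cycle. Below is a minimal sketch of doing that directly, outside a full databuilder job; the CSV path is a placeholder, and without a `model_class` each record is a plain dict keyed by the CSV header.

```python
# Minimal sketch: drive CsvExtractor directly with an unscoped config.
# '/tmp/tables.csv' is a hypothetical path; each row becomes a dict keyed by the header.
from pyhocon import ConfigFactory

from databuilder.extractor.csv_extractor import CsvExtractor

extractor = CsvExtractor()
extractor.init(ConfigFactory.from_dict({
    'file_location': '/tmp/tables.csv',   # CsvExtractor.FILE_LOCATION
    # 'model_class': '<dotted.path.to.SomeModel>',  # optional; each row then feeds the model constructor
}))

record = extractor.extract()
while record:
    print(record)              # a dict (or a model instance when model_class is configured)
    record = extractor.extract()
```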
+ """ + def init(self, conf: ConfigTree) -> None: + self.conf = conf + self.table_file_location = conf.get_string(CsvTableBadgeExtractor.TABLE_FILE_LOCATION) + self.badge_file_location = conf.get_string(CsvTableBadgeExtractor.BADGE_FILE_LOCATION) + self.badge_separator = conf.get_string(CsvTableBadgeExtractor.BADGE_SEPARATOR, default=',') + self._load_csv() + + def _get_key(self, + db: str, + cluster: str, + schema: str, + tbl: str + ) -> str: + return TableMetadata.TABLE_KEY_FORMAT.format(db=db, + cluster=cluster, + schema=schema, + tbl=tbl) + + def _load_csv(self) -> None: + with open(self.badge_file_location, 'r') as fin: + self.badges = [dict(i) for i in csv.DictReader(fin)] + # print("BADGES: " + str(self.badges)) + + parsed_badges = defaultdict(list) + for badge_dict in self.badges: + db = badge_dict['database'] + cluster = badge_dict['cluster'] + schema = badge_dict['schema'] + table_name = badge_dict['table_name'] + id = self._get_key(db, cluster, schema, table_name) + split_badges = split_badge_list(badges=badge_dict['name'], + separator=self.badge_separator) + for badge_name in split_badges: + badge = Badge(name=badge_name, category=badge_dict['category']) + parsed_badges[id].append(badge) + + with open(self.table_file_location, 'r') as fin: + tables = [dict(i) for i in csv.DictReader(fin)] + + results = [] + for table_dict in tables: + db = table_dict['database'] + cluster = table_dict['cluster'] + schema = table_dict['schema'] + table_name = table_dict['name'] + id = self._get_key(db, cluster, schema, table_name) + badges = parsed_badges[id] + + if badges is None: + badges = [] + badge_metadata = BadgeMetadata(start_label=TableMetadata.TABLE_NODE_LABEL, + start_key=id, + badges=badges) + results.append(badge_metadata) + self._iter = iter(results) + + def extract(self) -> Any: + """ + Yield the csv result one at a time. + convert the result to model if a model_class is provided + """ + try: + return next(self._iter) + except StopIteration: + return None + except Exception as e: + raise e + + def get_scope(self) -> str: + return 'extractor.csvtablebadge' + + +class CsvTableColumnExtractor(Extractor): + # Config keys + TABLE_FILE_LOCATION = 'table_file_location' + COLUMN_FILE_LOCATION = 'column_file_location' + BADGE_SEPARATOR = 'badge_separator' + + """ + An Extractor that combines Table and Column CSVs. + """ + + def init(self, conf: ConfigTree) -> None: + """ + :param conf: + """ + self.conf = conf + self.table_file_location = conf.get_string(CsvTableColumnExtractor.TABLE_FILE_LOCATION) + self.column_file_location = conf.get_string(CsvTableColumnExtractor.COLUMN_FILE_LOCATION) + self.badge_separator = conf.get_string(CsvTableBadgeExtractor.BADGE_SEPARATOR, default=',') + self._load_csv() + + def _get_key(self, + db: str, + cluster: str, + schema: str, + tbl: str + ) -> str: + return TableMetadata.TABLE_KEY_FORMAT.format(db=db, + cluster=cluster, + schema=schema, + tbl=tbl) + + def _load_csv(self) -> None: + """ + Create an iterator to execute sql. 
+ """ + with open(self.column_file_location, 'r') as fin: + self.columns = [dict(i) for i in csv.DictReader(fin)] + + parsed_columns = defaultdict(list) + for column_dict in self.columns: + db = column_dict['database'] + cluster = column_dict['cluster'] + schema = column_dict['schema'] + table_name = column_dict['table_name'] + id = self._get_key(db, cluster, schema, table_name) + split_badges = split_badge_list(badges=column_dict['badges'], + separator=self.badge_separator) + column = ColumnMetadata( + name=column_dict['name'], + description=column_dict['description'], + col_type=column_dict['col_type'], + sort_order=int(column_dict['sort_order']), + badges=split_badges + ) + parsed_columns[id].append(column) + + # Create Table Dictionary + with open(self.table_file_location, 'r') as fin: + tables = [dict(i) for i in csv.DictReader(fin)] + + results = [] + for table_dict in tables: + db = table_dict['database'] + cluster = table_dict['cluster'] + schema = table_dict['schema'] + table_name = table_dict['name'] + id = self._get_key(db, cluster, schema, table_name) + columns = parsed_columns[id] + if columns is None: + columns = [] + table = TableMetadata(database=table_dict['database'], + cluster=table_dict['cluster'], + schema=table_dict['schema'], + name=table_dict['name'], + description=table_dict['description'], + columns=columns, + # TODO: this possibly should parse stringified booleans; + # right now it only will be false for empty strings + is_view=bool(table_dict['is_view']), + tags=table_dict['tags'] + ) + results.append(table) + self._iter = iter(results) + + def extract(self) -> Any: + """ + Yield the csv result one at a time. + convert the result to model if a model_class is provided + """ + try: + return next(self._iter) + except StopIteration: + return None + except Exception as e: + raise e + + def get_scope(self) -> str: + return 'extractor.csvtablecolumn' + + +class CsvTableLineageExtractor(Extractor): + # Config keys + TABLE_LINEAGE_FILE_LOCATION = 'table_lineage_file_location' + + """ + An Extractor that creates Table Lineage between two tables + """ + + def init(self, conf: ConfigTree) -> None: + """ + :param conf: + """ + self.conf = conf + self.table_lineage_file_location = conf.get_string(CsvTableLineageExtractor.TABLE_LINEAGE_FILE_LOCATION) + self._load_csv() + + def _load_csv(self) -> None: + """ + Create an iterator to execute sql. + """ + + with open(self.table_lineage_file_location, 'r') as fin: + self.table_lineage = [dict(i) for i in csv.DictReader(fin)] + + results = [] + for lineage_dict in self.table_lineage: + source_table_key = lineage_dict['source_table_key'] + target_table_key = lineage_dict['target_table_key'] + lineage = TableLineage( + table_key=source_table_key, + downstream_deps=[target_table_key] + ) + results.append(lineage) + + self._iter = iter(results) + + def extract(self) -> Any: + """ + Yield the csv result one at a time. 
+ convert the result to model if a model_class is provided + """ + try: + return next(self._iter) + except StopIteration: + return None + except Exception as e: + raise e + + def get_scope(self) -> str: + return 'extractor.csvtablelineage' + + +class CsvColumnLineageExtractor(Extractor): + # Config keys + COLUMN_LINEAGE_FILE_LOCATION = 'column_lineage_file_location' + + """ + An Extractor that creates Column Lineage between two columns + """ + + def init(self, conf: ConfigTree) -> None: + """ + :param conf: + """ + self.conf = conf + self.column_lineage_file_location = conf.get_string(CsvColumnLineageExtractor.COLUMN_LINEAGE_FILE_LOCATION) + self._load_csv() + + def _load_csv(self) -> None: + """ + Create an iterator to execute sql. + """ + + with open(self.column_lineage_file_location, 'r') as fin: + self.column_lineage = [dict(i) for i in csv.DictReader(fin)] + + results = [] + for lineage_dict in self.column_lineage: + source_column_key = lineage_dict['source_column_key'] + target_column_key = lineage_dict['target_column_key'] + lineage = ColumnLineage( + column_key=source_column_key, + downstream_deps=[target_column_key] + ) + results.append(lineage) + + self._iter = iter(results) + + def extract(self) -> Any: + """ + Yield the csv result one at a time. + convert the result to model if a model_class is provided + """ + try: + return next(self._iter) + except StopIteration: + return None + except Exception as e: + raise e + + def get_scope(self) -> str: + return 'extractor.csvcolumnlineage' diff --git a/databuilder/databuilder/extractor/dashboard/__init__.py b/databuilder/databuilder/extractor/dashboard/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/extractor/dashboard/apache_superset/__init__.py b/databuilder/databuilder/extractor/dashboard/apache_superset/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/apache_superset/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/extractor/dashboard/apache_superset/apache_superset_chart_extractor.py b/databuilder/databuilder/extractor/dashboard/apache_superset/apache_superset_chart_extractor.py new file mode 100644 index 0000000000..9bbaa5f756 --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/apache_superset/apache_superset_chart_extractor.py @@ -0,0 +1,69 @@ +# Copyright Contributors to the Amundsen project. 
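All four CSV-driven extractors above follow the same pattern: read a CSV in `_load_csv`, build model objects, and hand them out one at a time from `extract()`. The sketch below exercises `CsvTableLineageExtractor` end to end with a throwaway file; the table keys are purely illustrative and should match however your tables are actually keyed in Amundsen.

```python
# A minimal sketch: write a small table-lineage CSV, then let the extractor read it back.
# The header must provide 'source_table_key' and 'target_table_key', which is what
# CsvTableLineageExtractor._load_csv reads from each row.
import csv
import tempfile

from pyhocon import ConfigFactory

from databuilder.extractor.csv_extractor import CsvTableLineageExtractor

rows = [
    # illustrative Amundsen-style table keys
    {'source_table_key': 'hive://gold.test_schema/orders',
     'target_table_key': 'hive://gold.test_schema/orders_daily_agg'},
]

with tempfile.NamedTemporaryFile('w', suffix='.csv', newline='', delete=False) as f:
    writer = csv.DictWriter(f, fieldnames=['source_table_key', 'target_table_key'])
    writer.writeheader()
    writer.writerows(rows)
    lineage_csv = f.name

extractor = CsvTableLineageExtractor()
extractor.init(ConfigFactory.from_dict({'table_lineage_file_location': lineage_csv}))

record = extractor.extract()
while record:
    print(record)      # a TableLineage with downstream_deps=[target_table_key]
    record = extractor.extract()
```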
+# SPDX-License-Identifier: Apache-2.0 + + +from typing import ( + Any, Dict, Iterator, Tuple, Union, +) + +from databuilder.extractor.dashboard.apache_superset.apache_superset_extractor import ( + ApacheSupersetBaseExtractor, type_fields_mapping, +) +from databuilder.models.dashboard.dashboard_chart import DashboardChart +from databuilder.models.dashboard.dashboard_query import DashboardQuery + + +class ApacheSupersetChartExtractor(ApacheSupersetBaseExtractor): + def chart_field_mappings(self) -> type_fields_mapping: + result = [ + ('chart_id', 'id', lambda x: str(x), ''), + ('chart_name', 'slice_name', None, ''), + ('chart_type', 'viz_type', None, ''), + ('chart_url', 'url', None, ''), # currently not available in superset chart api + ] + + return result + + def _get_extract_iter(self) -> Iterator[Union[DashboardQuery, DashboardChart, None]]: + ids = self._get_resource_ids('dashboard') + + data = [self._get_dashboard_details(i) for i in ids] + + for entry in data: + dashboard_id, dashboard_data = entry + + # Since Apache Superset doesn't support dashboard <> query <> chart relation we create a dummy 'bridge' + # query node so we can connect charts to a dashboard + dashboard_query_data = dict(dashboard_id=dashboard_id, + query_name='default', + query_id=self.dummy_query_id, + url='', + query_text='') + dashboard_query_data.update(**self.common_params) + + yield DashboardQuery(**dashboard_query_data) + + charts = [s.get('__Slice__') for s in dashboard_data.get('slices', [])] + + for chart in charts: + if chart: + dashboard_chart_data = self.map_fields(chart, self.chart_field_mappings()) + dashboard_chart_data.update(**{**dict(dashboard_id=dashboard_id, query_id=self.dummy_query_id), + **self.common_params}) + + yield DashboardChart(**dashboard_chart_data) + + def _get_dashboard_details(self, dashboard_id: str) -> Tuple[str, Dict[str, Any]]: + url = self.build_full_url(f'api/v1/dashboard/export?q=[{dashboard_id}]') + + _data = self.execute_query(url) + + dashboard_data = _data.get('dashboards', [dict()])[0].get('__Dashboard__', dict()) + + data = dashboard_id, dashboard_data + + return data + + @property + def dummy_query_id(self) -> str: + return '-1' diff --git a/databuilder/databuilder/extractor/dashboard/apache_superset/apache_superset_extractor.py b/databuilder/databuilder/extractor/dashboard/apache_superset/apache_superset_extractor.py new file mode 100644 index 0000000000..b9766b452d --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/apache_superset/apache_superset_extractor.py @@ -0,0 +1,187 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + + +import abc +from functools import reduce +from typing import ( + Any, Dict, Iterator, List, Tuple, +) + +import requests +from dateutil import parser +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor.base_extractor import Extractor + +type_fields_mapping = List[Tuple[str, str, Any, Any]] + + +class ApacheSupersetBaseExtractor(Extractor): + """ + Base class to create extractors pulling any dashboard metadata from Apache Superset. 
+ """ + APACHE_SUPERSET_PROTOCOL = 'apache_superset_protocol' + APACHE_SUPERSET_HOST = 'apache_superset_host' + APACHE_SUPERSET_PORT = 'apache_superset_port' + APACHE_SUPERSET_SECURITY_SETTINGS_DICT = 'apache_superset_security_settings_dict' + APACHE_SUPERSET_PAGE_SIZE = 'apache_superset_page_size' + APACHE_SUPERSET_EXTRACT_PUBLISHED_ONLY = 'apache_superset_extract_published_only' + APACHE_SUPERSET_SECURITY_PROVIDER = 'apache_superset_security_provider' + + DASHBOARD_GROUP_NAME = 'dashboard_group_name' + DASHBOARD_GROUP_ID = 'dashboard_group_id' + DASHBOARD_GROUP_DESCRIPTION = 'dashboard_group_description' + + PRODUCT = 'product' + CLUSTER = 'cluster' + + DRIVER_TO_DATABASE_MAPPING = 'driver_to_database_mapping' + + DEFAULT_DRIVER_TO_DATABASE_MAPPING = { + 'postgresql': 'postgres', + 'mysql+pymysql': 'mysql' + } + + DATABASE_TO_CLUSTER_MAPPING = 'database_to_cluster_mapping' # map superset dbs to preferred clusters + + DEFAULT_CONFIG = ConfigFactory.from_dict({ + APACHE_SUPERSET_PROTOCOL: 'http', + APACHE_SUPERSET_HOST: 'localhost', + APACHE_SUPERSET_PORT: '8088', + APACHE_SUPERSET_PAGE_SIZE: 20, + APACHE_SUPERSET_EXTRACT_PUBLISHED_ONLY: False, + PRODUCT: 'superset', + DASHBOARD_GROUP_DESCRIPTION: '', + DRIVER_TO_DATABASE_MAPPING: DEFAULT_DRIVER_TO_DATABASE_MAPPING, + DATABASE_TO_CLUSTER_MAPPING: {} + }) + + def init(self, conf: ConfigTree) -> None: + self.conf = conf.with_fallback(ApacheSupersetBaseExtractor.DEFAULT_CONFIG) + self._extract_iter = self._get_extract_iter() + + self.authenticate() + + def get_scope(self) -> str: + return 'extractor.apache_superset' + + def extract(self) -> Any: + try: + result = next(self._extract_iter) + + return result + except StopIteration: + return None + + def authenticate(self) -> None: + security_settings = dict(self.conf.get(ApacheSupersetBaseExtractor.APACHE_SUPERSET_SECURITY_SETTINGS_DICT)) + + token = requests.post(self.build_full_url('api/v1/security/login'), + json=security_settings) + + self.token = token.json()['access_token'] + + def build_full_url(self, endpoint: str) -> str: + return f'{self.base_url}/{endpoint}' + + def execute_query(self, url: str, params: dict = {}) -> Dict: + try: + data = requests.get(url, params=params, headers={'Authorization': f'Bearer {self.token}'}) + + if data.status_code == 401: + self.authenticate() + + return self.execute_query(url, params) + else: + return data.json() + except Exception: + return {} + + @property + def base_url(self) -> str: + _protocol = self.conf.get(ApacheSupersetBaseExtractor.APACHE_SUPERSET_PROTOCOL) + _host = self.conf.get(ApacheSupersetBaseExtractor.APACHE_SUPERSET_HOST) + _port = self.conf.get(ApacheSupersetBaseExtractor.APACHE_SUPERSET_PORT) + + return f'{_protocol}://{_host}:{_port}' + + @property + def page_size(self) -> int: + return self.conf.get_int(ApacheSupersetBaseExtractor.APACHE_SUPERSET_PAGE_SIZE) + + @property + def product(self) -> str: + return self.conf.get(ApacheSupersetBaseExtractor.PRODUCT) + + @property + def cluster(self) -> str: + return self.conf.get(ApacheSupersetBaseExtractor.CLUSTER) + + @property + def extract_published_only(self) -> bool: + return self.conf.get(ApacheSupersetBaseExtractor.APACHE_SUPERSET_EXTRACT_PUBLISHED_ONLY) + + @property + def common_params(self) -> Dict[str, str]: + return dict(dashboard_group=self.conf.get(ApacheSupersetBaseExtractor.DASHBOARD_GROUP_NAME), + dashboard_group_id=self.conf.get(ApacheSupersetBaseExtractor.DASHBOARD_GROUP_ID), + dashboard_group_url=self.base_url, + 
dashboard_group_description=self.conf.get(ApacheSupersetBaseExtractor.DASHBOARD_GROUP_DESCRIPTION), + product=self.product, + cluster=self.cluster) + + @staticmethod + def parse_date(string_date: str) -> int: + try: + date_parsed = parser.parse(string_date) + + # date returned by superset api does not contain timezone so to be timezone safe we need to assume it's utc + if not date_parsed.tzname(): + return ApacheSupersetBaseExtractor.parse_date(f'{string_date}+0000') + + return int(date_parsed.timestamp()) + except Exception: + return 0 + + @staticmethod + def get_nested_field(input_dict: Dict, field: str) -> Any: + return reduce(lambda x, y: x.get(y, dict()), field.split('.'), input_dict) + + @staticmethod + def map_fields(data: Dict, mappings: type_fields_mapping) -> Dict: + result = dict() + + for mapping in mappings: + target_field, source_field, transform, default_value = mapping + value = ApacheSupersetBaseExtractor.get_nested_field(data, source_field) + + if transform: + value = transform(value) + + result[target_field] = value or default_value + + return result + + @abc.abstractmethod + def _get_extract_iter(self) -> Iterator[Any]: + pass + + def _get_resource_ids(self, resource: str) -> List[str]: + i = 0 + result = [] + + while True: + url = self.build_full_url(f'api/v1/{resource}?q=(page_size:{self.page_size},page:{i},order_direction:desc)') + + data = self.execute_query(url) + + ids = data.get('ids', []) + + if ids: + result += ids + i += 1 + else: + break + + return result diff --git a/databuilder/databuilder/extractor/dashboard/apache_superset/apache_superset_metadata_extractor.py b/databuilder/databuilder/extractor/dashboard/apache_superset/apache_superset_metadata_extractor.py new file mode 100644 index 0000000000..bdaf19d6fc --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/apache_superset/apache_superset_metadata_extractor.py @@ -0,0 +1,63 @@ +# Copyright Contributors to the Amundsen project. 
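Putting the configuration keys of `ApacheSupersetBaseExtractor` together, a job would typically feed one of its concrete subclasses a tree like the one below. Host, port, group names and credentials are placeholders, and the exact shape of the security settings dict depends on your Superset deployment; the username/password/provider payload shown is only an assumption about what `/api/v1/security/login` expects.

```python
# Configuration sketch for the Apache Superset extractors (all values are placeholders).
from pyhocon import ConfigFactory

from databuilder.extractor.dashboard.apache_superset.apache_superset_chart_extractor import (
    ApacheSupersetChartExtractor,
)

superset_conf = ConfigFactory.from_dict({
    'apache_superset_protocol': 'https',
    'apache_superset_host': 'superset.example.com',   # placeholder host
    'apache_superset_port': '443',
    'apache_superset_page_size': 20,
    'dashboard_group_name': 'Superset',               # DASHBOARD_GROUP_NAME
    'dashboard_group_id': 'superset',                 # DASHBOARD_GROUP_ID
    'cluster': 'gold',                                # CLUSTER
    # posted verbatim to api/v1/security/login; adjust to your auth setup
    'apache_superset_security_settings_dict': {
        'username': 'admin', 'password': 'secret', 'provider': 'db', 'refresh': True,
    },
})

extractor = ApacheSupersetChartExtractor()
extractor.init(superset_conf)   # logs in against the configured host; extract() then pages through dashboards
```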
+# SPDX-License-Identifier: Apache-2.0 + + +from typing import ( + Any, Dict, Iterator, Union, +) + +from databuilder.extractor.dashboard.apache_superset.apache_superset_extractor import ( + ApacheSupersetBaseExtractor, type_fields_mapping, +) +from databuilder.models.dashboard.dashboard_last_modified import DashboardLastModifiedTimestamp +from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata + + +class ApacheSupersetMetadataExtractor(ApacheSupersetBaseExtractor): + def last_modified_field_mappings(self) -> type_fields_mapping: + result = [ + ('dashboard_id', 'result.id', lambda x: str(x), ''), + ('dashboard_name', 'result.dashboard_title', None, ''), + ('last_modified_timestamp', 'result.changed_on', + lambda _date: self.parse_date(_date), 0) + ] + + return result + + def dashboard_metadata_field_mappings(self) -> type_fields_mapping: + result = [ + ('dashboard_id', 'result.id', lambda x: str(x), ''), + ('dashboard_name', 'result.dashboard_title', None, ''), + ('dashboard_url', 'result.url', lambda x: self.base_url + x, ''), + ('created_timestamp', 'result.created_on', None, 0), # currently not available in superset dashboard api + ('tags', '', lambda x: [x] if x else [], []), # not available + ('description', 'result.description', None, '') # currently not available in superset dashboard api + ] + + return result + + def _get_extract_iter(self) -> Iterator[Union[DashboardMetadata, DashboardLastModifiedTimestamp, None]]: + ids = self._get_resource_ids('dashboard') + + data = [self._get_dashboard_details(i) for i in ids] + + if self.extract_published_only: + data = [d for d in data if self.get_nested_field(d, 'result.published')] + + for entry in data: + dashboard_metadata = self.map_fields(entry, self.dashboard_metadata_field_mappings()) + dashboard_metadata.update(**self.common_params) + + yield DashboardMetadata(**dashboard_metadata) + + dashboard_last_modified = self.map_fields(entry, self.last_modified_field_mappings()) + dashboard_last_modified.update(**self.common_params) + + yield DashboardLastModifiedTimestamp(**dashboard_last_modified) + + def _get_dashboard_details(self, dashboard_id: str) -> Dict[str, Any]: + url = self.build_full_url(f'api/v1/dashboard/{dashboard_id}') + + data = self.execute_query(url) + + return data diff --git a/databuilder/databuilder/extractor/dashboard/apache_superset/apache_superset_table_extractor.py b/databuilder/databuilder/extractor/dashboard/apache_superset/apache_superset_table_extractor.py new file mode 100644 index 0000000000..49dc2df298 --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/apache_superset/apache_superset_table_extractor.py @@ -0,0 +1,100 @@ +# Copyright Contributors to the Amundsen project. 
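A small, self-contained illustration of how the field mappings above are applied: `map_fields` walks each `(target_field, source_field, transform, default)` tuple, pulls the dotted source path out of the API payload with `get_nested_field`, applies the optional transform, and falls back to the default for missing or empty values. The payload below is invented purely for illustration.

```python
# Demonstrates map_fields / get_nested_field from ApacheSupersetBaseExtractor on a fake payload.
from databuilder.extractor.dashboard.apache_superset.apache_superset_extractor import (
    ApacheSupersetBaseExtractor,
)

payload = {'result': {'id': 42, 'dashboard_title': 'Sales overview', 'url': '/superset/dashboard/42/'}}

mappings = [
    ('dashboard_id', 'result.id', lambda x: str(x), ''),
    ('dashboard_name', 'result.dashboard_title', None, ''),
    ('description', 'result.description', None, ''),   # missing in payload -> falls back to ''
]

print(ApacheSupersetBaseExtractor.map_fields(payload, mappings))
# {'dashboard_id': '42', 'dashboard_name': 'Sales overview', 'description': ''}
```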
+# SPDX-License-Identifier: Apache-2.0 + + +from functools import lru_cache +from typing import ( + Any, Dict, Iterator, Union, +) + +from sqlalchemy.engine.url import make_url + +from databuilder.extractor.dashboard.apache_superset.apache_superset_extractor import ApacheSupersetBaseExtractor +from databuilder.models.dashboard.dashboard_table import DashboardTable +from databuilder.models.table_metadata import TableMetadata + + +class ApacheSupersetTableExtractor(ApacheSupersetBaseExtractor): + def _get_extract_iter(self) -> Iterator[Union[DashboardTable, None]]: + dashboards: Dict[str, set] = dict() + + ids = self._get_resource_ids('dataset') + + data = [(self._get_dataset_details(i), self._get_dataset_related_objects(i)) for i in ids] + + for entry in data: + dataset_details, dataset_objects = entry + + database_id = self.get_nested_field(dataset_details, 'result.database.id') + + if database_id: + database_details = self._get_database_details(database_id) + + sql = self.get_nested_field(dataset_details, 'result.sql') or '' + + # if sql exists then table_name cannot be associated with physical table in db + if not len(sql) > 0: + uri = self.get_nested_field(database_details, 'result.sqlalchemy_uri') + database_spec = make_url(uri) + + db = self.driver_mapping.get(database_spec.drivername, database_spec.drivername) + schema = database_spec.database + + cluster = self.cluster_mapping.get(str(database_id), self.cluster) + tbl = self.get_nested_field(dataset_details, 'result.table_name') + + table_key = TableMetadata.TABLE_KEY_FORMAT.format(db=db, + cluster=cluster, + schema=schema, + tbl=tbl) + + for dashboard in dataset_objects.get('dashboards', dict()).get('result', []): + dashboard_id = str(dashboard.get('id')) + + if not dashboards.get(dashboard_id): + dashboards[dashboard_id] = set() + + dashboards[dashboard_id].add(table_key) + else: + pass + else: + continue + + for dashboard_id, table_keys in dashboards.items(): + table_metadata: Dict[str, Any] = {'dashboard_id': dashboard_id, 'table_ids': table_keys} + + table_metadata.update(**self.common_params) + + result = DashboardTable(**table_metadata) + + yield result + + def _get_dataset_details(self, dataset_id: str) -> Dict[str, Any]: + url = self.build_full_url(f'api/v1/dataset/{dataset_id}') + + data = self.execute_query(url) + + return data + + def _get_dataset_related_objects(self, dataset_id: str) -> Dict[str, Any]: + url = self.build_full_url(f'api/v1/dataset/{dataset_id}/related_objects') + + data = self.execute_query(url) + + return data + + @lru_cache(maxsize=512) + def _get_database_details(self, database_id: str) -> Dict[str, Any]: + url = self.build_full_url(f'api/v1/database/{database_id}') + + data = self.execute_query(url) + + return data + + @property + def driver_mapping(self) -> Dict[str, str]: + return self.conf.get(self.DRIVER_TO_DATABASE_MAPPING) + + @property + def cluster_mapping(self) -> Dict[str, str]: + return self.conf.get(self.DATABASE_TO_CLUSTER_MAPPING) diff --git a/databuilder/databuilder/extractor/dashboard/databricks_sql/__init__.py b/databuilder/databuilder/extractor/dashboard/databricks_sql/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/databricks_sql/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. 
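The table extractor above connects dashboards to tables by reconstructing each dataset's Amundsen table key from its backing database's SQLAlchemy URI. The following standalone sketch mirrors that logic with an invented URI and dataset name.

```python
# Mirrors ApacheSupersetTableExtractor's key construction; URI, cluster and table are invented.
from sqlalchemy.engine.url import make_url

from databuilder.models.table_metadata import TableMetadata

driver_to_database = {'postgresql': 'postgres', 'mysql+pymysql': 'mysql'}  # DEFAULT_DRIVER_TO_DATABASE_MAPPING

sqlalchemy_uri = 'postgresql://user:pass@warehouse.example.com:5432/analytics'  # hypothetical
database_spec = make_url(sqlalchemy_uri)

db = driver_to_database.get(database_spec.drivername, database_spec.drivername)  # 'postgres'
schema = database_spec.database                                                  # 'analytics'
cluster = 'gold'                                                                  # DATABASE_TO_CLUSTER_MAPPING or CLUSTER fallback
tbl = 'orders'                                                                    # the dataset's result.table_name

table_key = TableMetadata.TABLE_KEY_FORMAT.format(db=db, cluster=cluster, schema=schema, tbl=tbl)
print(table_key)   # e.g. postgres://gold.analytics/orders
```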
+# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/extractor/dashboard/databricks_sql/databricks_sql_dashboard_extractor.py b/databuilder/databuilder/extractor/dashboard/databricks_sql/databricks_sql_dashboard_extractor.py new file mode 100644 index 0000000000..1cc3eb3dd4 --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/databricks_sql/databricks_sql_dashboard_extractor.py @@ -0,0 +1,186 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + Any, Dict, Iterator, Optional, +) + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.dashboard.databricks_sql.databricks_sql_dashboard_utils import ( + DatabricksSQLPaginatedRestApiQuery, generate_dashboard_description, get_text_widgets, get_visualization_widgets, + sort_widgets, +) +from databuilder.extractor.restapi.rest_api_extractor import REST_API_QUERY, RestAPIExtractor +from databuilder.models.dashboard.dashboard_chart import DashboardChart +from databuilder.models.dashboard.dashboard_last_modified import DashboardLastModifiedTimestamp +from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata +from databuilder.models.dashboard.dashboard_owner import DashboardOwner +from databuilder.models.dashboard.dashboard_query import DashboardQuery +from databuilder.rest_api.base_rest_api_query import EmptyRestApiQuerySeed +from databuilder.rest_api.rest_api_query import RestApiQuery +from databuilder.transformer.base_transformer import ChainedTransformer +from databuilder.transformer.timestamp_string_to_epoch import FIELD_NAME as TS_FIELD_NAME, TimestampStringToEpoch + + +class DatabricksSQLDashboardExtractor(Extractor): + """ + An extractor for retrieving dashboards, queries, and visualizations + from Databricks SQL (https://databricks.com/product/databricks-sql) + """ + + DATABRICKS_HOST_KEY = "databricks_host" + DATABRICKS_API_TOKEN_KEY = "databricks_api_token" + + PRODUCT = "databricks-sql" + DASHBOARD_GROUP_ID = "databricks-sql" + DASHBOARD_GROUP_NAME = "Databricks SQL" + + def init(self, conf: ConfigTree) -> None: + # Required configuration + self._databricks_host = conf.get_string( + DatabricksSQLDashboardExtractor.DATABRICKS_HOST_KEY + ) + self._databricks_api_token = conf.get_string( + DatabricksSQLDashboardExtractor.DATABRICKS_API_TOKEN_KEY + ) + + # NOTE: The dashboards api is currently in preview. 
When it gets moved out of preview + # this will break and it will need to be changed + self._databricks_sql_dashboards_api_base = ( + f"{self._databricks_host}/api/2.0/preview/sql/dashboards" + ) + + self._extractor = self._build_extractor() + self._transformer = self._build_transformer() + self._extract_iter: Optional[Iterator[Any]] = None + + def _get_databrick_request_headers(self) -> Dict[str, str]: + return { + "Authorization": f"Bearer {self._databricks_api_token}", + } + + def _get_extract_iter(self) -> Iterator[Any]: + while True: + record = self._extractor.extract() + if not record: + break + + record = next(self._transformer.transform(record=record), None) + dashboard_identity_data = { + "dashboard_group_id": DatabricksSQLDashboardExtractor.DASHBOARD_GROUP_ID, + "dashboard_id": record["dashboard_id"], + "product": "databricks-sql", + } + + dashboard_data = { + "dashboard_group": DatabricksSQLDashboardExtractor.DASHBOARD_GROUP_NAME, + "dashboard_name": record["dashboard_name"], + "dashboard_url": f"{self._databricks_host}/sql/dashboards/{record['dashboard_id']}", + "dashboard_group_url": self._databricks_host, + "created_timestamp": record["created_timestamp"], + "tags": record["tags"], + } + + dashboard_owner_data = {"email": record["user"]["email"]} + dashboard_owner_data.update(dashboard_identity_data) + yield DashboardOwner(**dashboard_owner_data) + + dashboard_last_modified_data = { + "last_modified_timestamp": record["last_modified_timestamp"], + } + dashboard_last_modified_data.update(dashboard_identity_data) + yield DashboardLastModifiedTimestamp(**dashboard_last_modified_data) + + if "widgets" in record: + widgets = sort_widgets(record["widgets"]) + text_widgets = get_text_widgets(widgets) + viz_widgets = get_visualization_widgets(widgets) + dashboard_data["description"] = generate_dashboard_description( + text_widgets, viz_widgets + ) + + for viz in viz_widgets: + dashboard_query_data = { + "query_id": str(viz.query_id), + "query_name": viz.query_name, + "url": self._databricks_host + viz.query_relative_url, + "query_text": viz.raw_query, + } + dashboard_query_data.update(dashboard_identity_data) + yield DashboardQuery(**dashboard_query_data) + + dashboard_chart_data = { + "query_id": str(viz.query_id), + "chart_id": str(viz.visualization_id), + "chart_name": viz.visualization_name, + "chart_type": viz.visualization_type, + } + dashboard_chart_data.update(dashboard_identity_data) + yield DashboardChart(**dashboard_chart_data) + + dashboard_data.update(dashboard_identity_data) + yield DashboardMetadata(**dashboard_data) + + def extract(self) -> Any: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def _build_extractor(self) -> RestAPIExtractor: + extractor = RestAPIExtractor() + rest_api_extractor_conf = ConfigFactory.from_dict( + {REST_API_QUERY: self._build_restapi_query()} + ) + extractor.init(rest_api_extractor_conf) + return extractor + + def _build_transformer(self) -> ChainedTransformer: + transformers = [] + # transform timestamps from ISO to unix epoch + ts_transformer_1 = TimestampStringToEpoch() + ts_transformer_1.init( + ConfigFactory.from_dict({TS_FIELD_NAME: "created_timestamp"}) + ) + transformers.append(ts_transformer_1) + + ts_transformer_2 = TimestampStringToEpoch() + ts_transformer_2.init( + ConfigFactory.from_dict({TS_FIELD_NAME: "last_modified_timestamp"}) + ) + transformers.append(ts_transformer_2) + + return 
ChainedTransformer(transformers=transformers) + + def _build_restapi_query(self) -> RestApiQuery: + databricks_sql_dashboard_query = DatabricksSQLPaginatedRestApiQuery( + query_to_join=EmptyRestApiQuerySeed(), + url=self._databricks_sql_dashboards_api_base, + params={"headers": self._get_databrick_request_headers()}, + json_path="results[*].[id,name,tags,updated_at,created_at,user]", + field_names=[ + "dashboard_id", + "dashboard_name", + "tags", + "last_modified_timestamp", + "created_timestamp", + "user", + ], + skip_no_results=True, + ) + + return RestApiQuery( + query_to_join=databricks_sql_dashboard_query, + url=f"{self._databricks_sql_dashboards_api_base}/{{dashboard_id}}", + params={"headers": self._get_databrick_request_headers()}, + json_path="widgets", + field_names=["widgets"], + skip_no_result=True, + ) + + def get_scope(self) -> str: + return "extractor.databricks_sql_extractor" diff --git a/databuilder/databuilder/extractor/dashboard/databricks_sql/databricks_sql_dashboard_utils.py b/databuilder/databuilder/extractor/dashboard/databricks_sql/databricks_sql_dashboard_utils.py new file mode 100644 index 0000000000..c101976780 --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/databricks_sql/databricks_sql_dashboard_utils.py @@ -0,0 +1,167 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + Any, Dict, Iterable, List, Tuple, +) + +from databuilder.rest_api.rest_api_query import RestApiQuery + + +class DatabricksSQLVisualizationWidget: + """ + A visualization widget in a Databricks SQL dashboard. + These are mapped 1:1 with queries, and can be of various types, e.g.: + CHART, TABLE, PIVOT, etc. + The query name acts like a title for the widget on the dashboard. + """ + + def __init__(self, data: Dict[str, Any]) -> None: + self._data = data + + @property + def raw_query(self) -> str: + return self._data["visualization"]["query"]["query"] + + @property + def data_source_id(self) -> int: + return self._data["visualization"]["query"]["data_source_id"] + + @property + def query_id(self) -> int: + return self._data["visualization"]["query"]["id"] + + @property + def query_relative_url(self) -> str: + return f"/queries/{self.query_id}" + + @property + def query_name(self) -> str: + return self._data["visualization"]["query"]["name"] + + @property + def visualization_id(self) -> int: + return self._data["visualization"]["id"] + + @property + def visualization_name(self) -> str: + return self._data["visualization"]["name"] + + @property + def visualization_type(self) -> str: + return self._data["visualization"]["type"] + + +class DatabricksSQLTextWidget: + """ + A textbox in a Databricks SQL dashboard. + It pretty much just contains a single text property (Markdown). + """ + + def __init__(self, data: Dict[str, Any]) -> None: + self._data = data + + @property + def text(self) -> str: + return self._data["text"] + + +def sort_widgets(widgets: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Sort raw widget data (as returned from the API) according to the position + of the widgets in the dashboard (top to bottom, left to right) + Redash does not return widgets in order of their position, + so we do this to ensure that we look at widgets in a sensible order. 
+ """ + + def row_and_col(widget: Dict[str, Any]) -> Tuple[Any, Any]: + # these entities usually but not always have explicit rows and cols + pos = widget["options"].get("position", {}) + return (pos.get("row", 0), pos.get("col", 0)) + + return sorted(widgets, key=row_and_col) + + +def get_text_widgets( + widgets: Iterable[Dict[str, Any]] +) -> List[DatabricksSQLTextWidget]: + """ + From the raw set of widget data returned from the API, filter down + to text widgets and return them as a list of `RedashTextWidget` + """ + + return [ + DatabricksSQLTextWidget(widget) + for widget in widgets + if "text" in widget and "visualization" not in widget + ] + + +def get_visualization_widgets( + widgets: Iterable[Dict[str, Any]] +) -> List[DatabricksSQLVisualizationWidget]: + """ + From the raw set of widget data returned from the API, filter down + to visualization widgets and return them as a list of `RedashVisualizationWidget` + """ + + return [ + DatabricksSQLVisualizationWidget(widget) + for widget in widgets + if "visualization" in widget + ] + + +def get_auth_headers(api_key: str) -> Dict[str, str]: + return {"Authorization": f"Bearer {api_key}"} + + +def generate_dashboard_description( + text_widgets: List[DatabricksSQLTextWidget], + viz_widgets: List[DatabricksSQLVisualizationWidget], +) -> str: + """ + Redash doesn't have dashboard descriptions, so we'll make our own. + If there exist any text widgets, concatenate them, + and use this text as the description for this dashboard. + If not, put together a list of query names. + If all else fails, this looks like an empty dashboard. + """ + + if len(text_widgets) > 0: + return "\n\n".join([w.text for w in text_widgets]) + elif len(viz_widgets) > 0: + query_list = "\n".join([f"- {v.query_name}" for v in set(viz_widgets)]) + return "A dashboard containing the following queries:\n\n" + query_list + + return "This dashboard appears to be empty!" + + +class DatabricksSQLPaginatedRestApiQuery(RestApiQuery): + """ + Paginated Databricks SQL API queries + """ + + def __init__(self, **kwargs: Any) -> None: + super(DatabricksSQLPaginatedRestApiQuery, self).__init__(**kwargs) + if "params" not in self._params: + self._params["params"] = {} + self._params["params"]["page"] = 1 + + def _total_records(self, res: Dict[str, Any]) -> int: + return res["count"] + + def _max_record_on_page(self, res: Dict[str, Any]) -> int: + return res["page_size"] * res["page"] + + def _next_page(self, res: Dict[str, Any]) -> int: + return res["page"] + 1 + + def _post_process(self, response: Any) -> None: + parsed = response.json() + + if self._max_record_on_page(parsed) >= self._total_records(parsed): + self._more_pages = False + else: + self._params["params"]["page"] = self._next_page(parsed) + self._more_pages = True diff --git a/databuilder/databuilder/extractor/dashboard/mode_analytics/__init__.py b/databuilder/databuilder/extractor/dashboard/mode_analytics/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/mode_analytics/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_charts_batch_extractor.py b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_charts_batch_extractor.py new file mode 100644 index 0000000000..91c790a46b --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_charts_batch_extractor.py @@ -0,0 +1,88 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Any + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_constants import ORGANIZATION +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_utils import ModeDashboardUtils +from databuilder.rest_api.base_rest_api_query import RestApiQuerySeed +from databuilder.rest_api.mode_analytics.mode_paginated_rest_api_query import ModePaginatedRestApiQuery +from databuilder.rest_api.rest_api_query import RestApiQuery +from databuilder.transformer.dict_to_model import MODEL_CLASS, DictToModel + +LOGGER = logging.getLogger(__name__) + + +class ModeDashboardChartsBatchExtractor(Extractor): + """ + Mode dashboard chart extractor leveraging batch / discovery endpoint. + The detail could be found in https://mode.com/help/articles/discovery-api/#list-charts-for-an-organization + """ + # config to include the charts from all space + INCLUDE_ALL_SPACE = 'include_all_space' + + def init(self, conf: ConfigTree) -> None: + self._conf = conf + restapi_query = self._build_restapi_query() + self._extractor = ModeDashboardUtils.create_mode_rest_api_extractor( + restapi_query=restapi_query, + conf=self._conf + ) + + dict_to_model_transformer = DictToModel() + dict_to_model_transformer.init( + conf=Scoped.get_scoped_conf(self._conf, dict_to_model_transformer.get_scope()).with_fallback( + ConfigFactory.from_dict( + {MODEL_CLASS: 'databuilder.models.dashboard.dashboard_chart.DashboardChart'}))) + self._transformer = dict_to_model_transformer + + def extract(self) -> Any: + + record = self._extractor.extract() + if not record: + return None + return self._transformer.transform(record=record) + + def get_scope(self) -> str: + return 'extractor.mode_dashboard_chart_batch' + + def _build_restapi_query(self) -> RestApiQuery: + """ + Build a paginated REST API based on Mode discovery API + :return: + """ + params = ModeDashboardUtils.get_auth_params(conf=self._conf, discover_auth=True) + + seed_record = [{ + 'organization': self._conf.get_string(ORGANIZATION), + 'is_active': None, + 'updated_at': None, + 'do_not_update_empty_attribute': True, + }] + seed_query = RestApiQuerySeed(seed_record=seed_record) + + chart_url_template = 'http://app.mode.com/batch/{organization}/charts' + if self._conf.get_bool(ModeDashboardChartsBatchExtractor.INCLUDE_ALL_SPACE, default=False): + chart_url_template += '?include_spaces=all' + json_path = '(charts[*].[space_token,report_token,query_token,token,chart_title,chart_type])' + field_names = ['dashboard_group_id', + 'dashboard_id', + 'query_id', + 'chart_id', + 'chart_name', + 'chart_type'] + max_record_size = 1000 + chart_batch_query = ModePaginatedRestApiQuery(query_to_join=seed_query, + url=chart_url_template, + params=params, + json_path=json_path, + pagination_json_path=json_path, + field_names=field_names, + skip_no_result=True, + 
max_record_size=max_record_size) + return chart_batch_query diff --git a/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_constants.py b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_constants.py new file mode 100644 index 0000000000..0b82ddf42f --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_constants.py @@ -0,0 +1,10 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +ORGANIZATION = 'organization' +MODE_ACCESS_TOKEN = 'mode_user_token' +MODE_PASSWORD_TOKEN = 'mode_password_token' + +# this token is needed to access batch discover endpoint +# e.g https://mode.com/developer/discovery-api/introduction/ +MODE_BEARER_TOKEN = 'mode_bearer_token' diff --git a/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_executions_extractor.py b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_executions_extractor.py new file mode 100644 index 0000000000..f5d93ca937 --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_executions_extractor.py @@ -0,0 +1,85 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Any, List + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_utils import ModeDashboardUtils +from databuilder.rest_api.mode_analytics.mode_paginated_rest_api_query import ModePaginatedRestApiQuery +from databuilder.transformer.base_transformer import ChainedTransformer, Transformer +from databuilder.transformer.dict_to_model import MODEL_CLASS, DictToModel +from databuilder.transformer.timestamp_string_to_epoch import FIELD_NAME, TimestampStringToEpoch + +LOGGER = logging.getLogger(__name__) + + +class ModeDashboardExecutionsExtractor(Extractor): + """ + A Extractor that extracts run (execution) status and timestamp. + + """ + + def init(self, conf: ConfigTree) -> None: + self._conf = conf + + restapi_query = self._build_restapi_query() + self._extractor = ModeDashboardUtils.create_mode_rest_api_extractor( + restapi_query=restapi_query, + conf=self._conf + ) + + # Payload from RestApiQuery has timestamp which is ISO8601. 
Here we are using TimestampStringToEpoch to + # transform into epoch and then using DictToModel to convert Dictionary to Model + transformers: List[Transformer] = [] + timestamp_str_to_epoch_transformer = TimestampStringToEpoch() + timestamp_str_to_epoch_transformer.init( + conf=Scoped.get_scoped_conf(self._conf, timestamp_str_to_epoch_transformer.get_scope()).with_fallback( + ConfigFactory.from_dict({FIELD_NAME: 'execution_timestamp', }))) + + transformers.append(timestamp_str_to_epoch_transformer) + + dict_to_model_transformer = DictToModel() + dict_to_model_transformer.init( + conf=Scoped.get_scoped_conf(self._conf, dict_to_model_transformer.get_scope()).with_fallback( + ConfigFactory.from_dict( + {MODEL_CLASS: 'databuilder.models.dashboard.dashboard_execution.DashboardExecution'}))) + transformers.append(dict_to_model_transformer) + + self._transformer = ChainedTransformer(transformers=transformers) + + def extract(self) -> Any: + record = self._extractor.extract() + if not record: + return None + + return self._transformer.transform(record=record) + + def get_scope(self) -> str: + return 'extractor.mode_dashboard_execution' + + def _build_restapi_query(self) -> ModePaginatedRestApiQuery: + """ + Build REST API Query to get Mode Dashboard last execution. + :return: A RestApiQuery that provides Mode Dashboard execution (run) + """ + + seed_query = ModeDashboardUtils.get_seed_query(conf=self._conf) + params = ModeDashboardUtils.get_auth_params(conf=self._conf, discover_auth=True) + + # Reports + # https://mode.com/developer/discovery-api/analytics/reports/ + url = 'https://app.mode.com/batch/{organization}/reports' + json_path = 'reports[*].[token, space_token, last_run_at, last_run_state]' + field_names = ['dashboard_id', 'dashboard_group_id', 'execution_timestamp', 'execution_state'] + max_record_size = 1000 + pagination_json_path = 'reports[*]' + last_execution_query = ModePaginatedRestApiQuery(query_to_join=seed_query, url=url, params=params, + json_path=json_path, field_names=field_names, + skip_no_result=True, max_record_size=max_record_size, + pagination_json_path=pagination_json_path) + + return last_execution_query diff --git a/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_extractor.py b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_extractor.py new file mode 100644 index 0000000000..f0ea135018 --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_extractor.py @@ -0,0 +1,124 @@ +# Copyright Contributors to the Amundsen project. 
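The Mode extractors above pull their credentials through `ModeDashboardUtils`, using the constant key names from `mode_dashboard_constants`. A configuration sketch with placeholder values; which tokens a particular extractor actually needs depends on whether it talks to the regular API or the batch discovery endpoint (the bearer token is for the latter).

```python
# Placeholder Mode credentials, keyed by the constants defined in mode_dashboard_constants.
from pyhocon import ConfigFactory

from databuilder.extractor.dashboard.mode_analytics import mode_dashboard_constants as mode_keys

mode_conf = ConfigFactory.from_dict({
    mode_keys.ORGANIZATION: 'my-mode-org',          # 'organization'
    mode_keys.MODE_ACCESS_TOKEN: 'token-XXXX',      # 'mode_user_token'
    mode_keys.MODE_PASSWORD_TOKEN: 'secret-XXXX',   # 'mode_password_token'
    mode_keys.MODE_BEARER_TOKEN: 'bearer-XXXX',     # 'mode_bearer_token', used for the discovery API
})

# e.g. ModeDashboardExecutionsExtractor().init(mode_conf) when running the extractor directly
```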
+# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Any, List + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_utils import ModeDashboardUtils +from databuilder.rest_api.mode_analytics.mode_paginated_rest_api_query import ModePaginatedRestApiQuery +from databuilder.rest_api.query_merger import QueryMerger +from databuilder.transformer.base_transformer import ChainedTransformer, Transformer +from databuilder.transformer.dict_to_model import MODEL_CLASS, DictToModel +from databuilder.transformer.template_variable_substitution_transformer import ( + FIELD_NAME as VAR_FIELD_NAME, TEMPLATE, TemplateVariableSubstitutionTransformer, +) +from databuilder.transformer.timestamp_string_to_epoch import FIELD_NAME, TimestampStringToEpoch + +LOGGER = logging.getLogger(__name__) + +# a list of space tokens that we want to skip indexing +DASHBOARD_GROUP_IDS_TO_SKIP = 'dashboard_group_ids_to_skip' + + +class ModeDashboardExtractor(Extractor): + """ + A Extractor that extracts core metadata on Mode dashboard. https://app.mode.com/ + It extracts list of reports that consists of: + Dashboard group name (Space name) + Dashboard group id (Space token) + Dashboard group description (Space description) + Dashboard name (Report name) + Dashboard id (Report token) + Dashboard description (Report description) + + Other information such as report run, owner, chart name, query name is in separate extractor. + """ + + def init(self, conf: ConfigTree) -> None: + self._conf = conf + + self.dashboard_group_ids_to_skip = self._conf.get_list(DASHBOARD_GROUP_IDS_TO_SKIP, []) + + restapi_query = self._build_restapi_query() + self._extractor = ModeDashboardUtils.create_mode_rest_api_extractor(restapi_query=restapi_query, + conf=self._conf) + + # Payload from RestApiQuery has timestamp which is ISO8601. 
Here we are using TimestampStringToEpoch to + # transform into epoch and then using DictToModel to convert Dictionary to Model + transformers: List[Transformer] = [] + timestamp_str_to_epoch_transformer = TimestampStringToEpoch() + timestamp_str_to_epoch_transformer.init( + conf=Scoped.get_scoped_conf(self._conf, timestamp_str_to_epoch_transformer.get_scope()).with_fallback( + ConfigFactory.from_dict({FIELD_NAME: 'created_timestamp', }))) + + transformers.append(timestamp_str_to_epoch_transformer) + + dashboard_group_url_transformer = TemplateVariableSubstitutionTransformer() + dashboard_group_url_transformer.init( + conf=Scoped.get_scoped_conf(self._conf, dashboard_group_url_transformer.get_scope()).with_fallback( + ConfigFactory.from_dict({VAR_FIELD_NAME: 'dashboard_group_url', + TEMPLATE: 'https://app.mode.com/{organization}/spaces/{dashboard_group_id}'}))) + + transformers.append(dashboard_group_url_transformer) + + dashboard_url_transformer = TemplateVariableSubstitutionTransformer() + dashboard_url_transformer.init( + conf=Scoped.get_scoped_conf(self._conf, dashboard_url_transformer.get_scope()).with_fallback( + ConfigFactory.from_dict({VAR_FIELD_NAME: 'dashboard_url', + TEMPLATE: 'https://app.mode.com/{organization}/reports/{dashboard_id}'}))) + transformers.append(dashboard_url_transformer) + + dict_to_model_transformer = DictToModel() + dict_to_model_transformer.init( + conf=Scoped.get_scoped_conf(self._conf, dict_to_model_transformer.get_scope()).with_fallback( + ConfigFactory.from_dict( + {MODEL_CLASS: 'databuilder.models.dashboard.dashboard_metadata.DashboardMetadata'}))) + transformers.append(dict_to_model_transformer) + + self._transformer = ChainedTransformer(transformers=transformers) + + def extract(self) -> Any: + record = self._extractor.extract() + + # determine whether we want to skip these records + while record and record.get('dashboard_group_id') in self.dashboard_group_ids_to_skip: + record = self._extractor.extract() + + if not record: + return None + + return self._transformer.transform(record=record) + + def get_scope(self) -> str: + return 'extractor.mode_dashboard' + + def _build_restapi_query(self) -> ModePaginatedRestApiQuery: + """ + Build REST API Query to get Mode Dashboard metadata + :return: A RestApiQuery that provides Mode Dashboard metadata + """ + seed_query = ModeDashboardUtils.get_seed_query(conf=self._conf) + params = ModeDashboardUtils.get_auth_params(conf=self._conf, discover_auth=True) + + # Reports + # https://mode.com/developer/discovery-api/analytics/reports/ + url = 'https://app.mode.com/batch/{organization}/reports' + json_path = 'reports[*].[token, name, description, created_at, space_token]' + field_names = ['dashboard_id', 'dashboard_name', 'description', 'created_timestamp', 'dashboard_group_id'] + max_record_size = 1000 + pagination_json_path = 'reports[*]' + + spaces_query = ModeDashboardUtils.get_spaces_query_api(conf=self._conf) + query_merger = QueryMerger(query_to_merge=spaces_query, merge_key='dashboard_group_id') + + reports_query = ModePaginatedRestApiQuery(query_to_join=seed_query, url=url, params=params, + json_path=json_path, field_names=field_names, + skip_no_result=True, max_record_size=max_record_size, + pagination_json_path=pagination_json_path, + query_merger=query_merger) + + return reports_query diff --git a/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_last_modified_timestamp_extractor.py 
b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_last_modified_timestamp_extractor.py new file mode 100644 index 0000000000..63a1b3e6d1 --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_last_modified_timestamp_extractor.py @@ -0,0 +1,65 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_executions_extractor import ( + ModeDashboardExecutionsExtractor, +) +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_utils import ModeDashboardUtils +from databuilder.extractor.restapi.rest_api_extractor import STATIC_RECORD_DICT +from databuilder.rest_api.mode_analytics.mode_paginated_rest_api_query import ModePaginatedRestApiQuery +from databuilder.transformer.dict_to_model import MODEL_CLASS, DictToModel +from databuilder.transformer.timestamp_string_to_epoch import FIELD_NAME, TimestampStringToEpoch + +LOGGER = logging.getLogger(__name__) + + +class ModeDashboardLastModifiedTimestampExtractor(ModeDashboardExecutionsExtractor): + """ + A Extractor that extracts Mode dashboard's last modified timestamp. + + """ + + def __init__(self) -> None: + super(ModeDashboardLastModifiedTimestampExtractor, self).__init__() + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback( + ConfigFactory.from_dict({ + STATIC_RECORD_DICT: {'product': 'mode'}, + f'{DictToModel().get_scope()}.{MODEL_CLASS}': + 'databuilder.models.dashboard.dashboard_last_modified.DashboardLastModifiedTimestamp', + f'{TimestampStringToEpoch().get_scope()}.{FIELD_NAME}': + 'last_modified_timestamp' + }) + ) + super(ModeDashboardLastModifiedTimestampExtractor, self).init(conf) + + def get_scope(self) -> str: + return 'extractor.mode_dashboard_last_modified_timestamp_execution' + + def _build_restapi_query(self) -> ModePaginatedRestApiQuery: + """ + Build REST API Query to get Mode Dashboard last modified timestamp + :return: A RestApiQuery that provides Mode Dashboard last successful execution (run) + """ + + seed_query = ModeDashboardUtils.get_seed_query(conf=self._conf) + params = ModeDashboardUtils.get_auth_params(conf=self._conf, discover_auth=True) + + # Reports + # https://mode.com/developer/discovery-api/analytics/reports/ + url = 'https://app.mode.com/batch/{organization}/reports' + json_path = 'reports[*].[token, space_token, edited_at]' + field_names = ['dashboard_id', 'dashboard_group_id', 'last_modified_timestamp'] + max_record_size = 1000 + pagination_json_path = 'reports[*]' + last_modified_query = ModePaginatedRestApiQuery(query_to_join=seed_query, url=url, params=params, + json_path=json_path, field_names=field_names, + skip_no_result=True, max_record_size=max_record_size, + pagination_json_path=pagination_json_path) + + return last_modified_query diff --git a/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_last_successful_executions_extractor.py b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_last_successful_executions_extractor.py new file mode 100644 index 0000000000..d477750387 --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_last_successful_executions_extractor.py @@ -0,0 +1,62 @@ +# Copyright Contributors to the Amundsen project. 
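The last-modified extractor above reuses `ModeDashboardExecutionsExtractor` wholesale and only swaps configuration in via `with_fallback`. The pattern is plain pyhocon layering: keys supplied by the caller win, everything else falls back to the defaults. A tiny illustration with made-up keys:

```python
# Illustrates the pyhocon fallback pattern used throughout these extractors.
from pyhocon import ConfigFactory

defaults = ConfigFactory.from_dict({'organization': 'acme', 'include_all_space': False})
job_conf = ConfigFactory.from_dict({'organization': 'my-mode-org'})

conf = job_conf.with_fallback(defaults)
print(conf.get_string('organization'))      # 'my-mode-org' (caller's value wins)
print(conf.get_bool('include_all_space'))   # False (falls back to the default)
```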
+# SPDX-License-Identifier: Apache-2.0 + +import logging + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_executions_extractor import ( + ModeDashboardExecutionsExtractor, +) +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_utils import ModeDashboardUtils +from databuilder.extractor.restapi.rest_api_extractor import STATIC_RECORD_DICT +from databuilder.models.dashboard.dashboard_execution import DashboardExecution +from databuilder.rest_api.mode_analytics.mode_paginated_rest_api_query import ModePaginatedRestApiQuery + +LOGGER = logging.getLogger(__name__) + + +class ModeDashboardLastSuccessfulExecutionExtractor(ModeDashboardExecutionsExtractor): + """ + A Extractor that extracts Mode dashboard's last successful run (execution) timestamp. + + """ + + def __init__(self) -> None: + super(ModeDashboardLastSuccessfulExecutionExtractor, self).__init__() + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback( + ConfigFactory.from_dict({ + STATIC_RECORD_DICT: {'product': 'mode', + 'execution_state': 'succeeded', + 'execution_id': DashboardExecution.LAST_SUCCESSFUL_EXECUTION_ID} + }) + ) + super(ModeDashboardLastSuccessfulExecutionExtractor, self).init(conf) + + def get_scope(self) -> str: + return 'extractor.mode_dashboard_last_successful_execution' + + def _build_restapi_query(self) -> ModePaginatedRestApiQuery: + """ + Build REST API Query to get Mode Dashboard last successful execution metadata. + :return: A RestApiQuery that provides Mode Dashboard last successful execution (run) + """ + + seed_query = ModeDashboardUtils.get_seed_query(conf=self._conf) + params = ModeDashboardUtils.get_auth_params(conf=self._conf, discover_auth=True) + + # Reports + # https://mode.com/developer/discovery-api/analytics/reports/ + url = 'https://app.mode.com/batch/{organization}/reports' + json_path = 'reports[*].[token, space_token, last_successfully_run_at]' + field_names = ['dashboard_id', 'dashboard_group_id', 'execution_timestamp'] + max_record_size = 1000 + pagination_json_path = 'reports[*]' + last_successful_run_query = ModePaginatedRestApiQuery(query_to_join=seed_query, url=url, params=params, + json_path=json_path, field_names=field_names, + skip_no_result=True, max_record_size=max_record_size, + pagination_json_path=pagination_json_path) + + return last_successful_run_query diff --git a/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_owner_extractor.py b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_owner_extractor.py new file mode 100644 index 0000000000..067f20e64d --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_owner_extractor.py @@ -0,0 +1,63 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Any + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_utils import ModeDashboardUtils +from databuilder.extractor.restapi.rest_api_extractor import MODEL_CLASS +from databuilder.rest_api.mode_analytics.mode_paginated_rest_api_query import ModePaginatedRestApiQuery + +LOGGER = logging.getLogger(__name__) + + +class ModeDashboardOwnerExtractor(Extractor): + """ + An Extractor that extracts Dashboard owner. 
+ + """ + + def init(self, conf: ConfigTree) -> None: + self._conf = conf + + restapi_query = self._build_restapi_query() + self._extractor = ModeDashboardUtils.create_mode_rest_api_extractor( + restapi_query=restapi_query, + conf=self._conf.with_fallback( + ConfigFactory.from_dict( + {MODEL_CLASS: 'databuilder.models.dashboard.dashboard_owner.DashboardOwner', } + ) + ) + ) + + def extract(self) -> Any: + return self._extractor.extract() + + def get_scope(self) -> str: + return 'extractor.mode_dashboard_owner' + + def _build_restapi_query(self) -> ModePaginatedRestApiQuery: + """ + Build REST API Query to get Mode Dashboard owner + :return: A RestApiQuery that provides Mode Dashboard owner + """ + + seed_query = ModeDashboardUtils.get_seed_query(conf=self._conf) + params = ModeDashboardUtils.get_auth_params(conf=self._conf, discover_auth=True) + + # Reports + # https://mode.com/developer/discovery-api/analytics/reports/ + url = 'https://app.mode.com/batch/{organization}/reports' + json_path = 'reports[*].[token, space_token, creator_email]' + field_names = ['dashboard_id', 'dashboard_group_id', 'email'] + max_record_size = 1000 + pagination_json_path = 'reports[*]' + creator_query = ModePaginatedRestApiQuery(query_to_join=seed_query, url=url, params=params, + json_path=json_path, field_names=field_names, + skip_no_result=True, max_record_size=max_record_size, + pagination_json_path=pagination_json_path) + + return creator_query diff --git a/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_queries_extractor.py b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_queries_extractor.py new file mode 100644 index 0000000000..b80b02750d --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_queries_extractor.py @@ -0,0 +1,100 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Any, List + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_utils import ModeDashboardUtils +from databuilder.rest_api.mode_analytics.mode_paginated_rest_api_query import ModePaginatedRestApiQuery +from databuilder.transformer.base_transformer import ChainedTransformer, Transformer +from databuilder.transformer.dict_to_model import MODEL_CLASS, DictToModel +from databuilder.transformer.regex_str_replace_transformer import ( + ATTRIBUTE_NAME, REGEX_REPLACE_TUPLE_LIST, RegexStrReplaceTransformer, +) +from databuilder.transformer.template_variable_substitution_transformer import ( + FIELD_NAME, TEMPLATE, TemplateVariableSubstitutionTransformer, +) + +LOGGER = logging.getLogger(__name__) + + +class ModeDashboardQueriesExtractor(Extractor): + """ + A Extractor that extracts Query information + + """ + + def init(self, conf: ConfigTree) -> None: + self._conf = conf + + restapi_query = self._build_restapi_query() + self._extractor = ModeDashboardUtils.create_mode_rest_api_extractor( + restapi_query=restapi_query, + conf=self._conf + ) + + # Constructing URL using several ID via TemplateVariableSubstitutionTransformer + transformers: List[Transformer] = [] + variable_substitution_transformer = TemplateVariableSubstitutionTransformer() + variable_substitution_transformer.init( + conf=Scoped.get_scoped_conf(self._conf, + variable_substitution_transformer.get_scope()).with_fallback( + ConfigFactory.from_dict({FIELD_NAME: 'url', + TEMPLATE: 'https://app.mode.com/{organization}' + '/reports/{dashboard_id}/queries/{query_id}'}))) + + transformers.append(variable_substitution_transformer) + + # Escape backslash as it breaks Cypher statement. 
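# Note on the escaping step that follows: the tuple ('\\', '\\\\') configured for the
# RegexStrReplaceTransformer doubles every backslash in 'query_text', so a raw query
# containing  \d+  is emitted as  \\d+ . The backslash then survives as a literal
# character when the text is embedded in a Cypher statement instead of starting an
# escape sequence.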
+ replace_transformer = RegexStrReplaceTransformer() + replace_transformer.init( + conf=Scoped.get_scoped_conf(self._conf, replace_transformer.get_scope()).with_fallback( + ConfigFactory.from_dict( + {REGEX_REPLACE_TUPLE_LIST: [('\\', '\\\\')], ATTRIBUTE_NAME: 'query_text'}))) + transformers.append(replace_transformer) + + dict_to_model_transformer = DictToModel() + dict_to_model_transformer.init( + conf=Scoped.get_scoped_conf(self._conf, dict_to_model_transformer.get_scope()).with_fallback( + ConfigFactory.from_dict( + {MODEL_CLASS: 'databuilder.models.dashboard.dashboard_query.DashboardQuery'}))) + transformers.append(dict_to_model_transformer) + + self._transformer = ChainedTransformer(transformers=transformers) + + def extract(self) -> Any: + record = self._extractor.extract() + if not record: + return None + + return self._transformer.transform(record=record) + + def get_scope(self) -> str: + return 'extractor.mode_dashboard_query' + + def _build_restapi_query(self) -> ModePaginatedRestApiQuery: + """ + Build REST API Query to get Mode Dashboard queries + :return: A RestApiQuery that provides Mode Dashboard execution (run) + """ + + seed_query = ModeDashboardUtils.get_seed_query(conf=self._conf) + params = ModeDashboardUtils.get_auth_params(conf=self._conf, discover_auth=True) + + # Queries + # https://mode.com/developer/discovery-api/analytics/queries/ + url = 'https://app.mode.com/batch/{organization}/queries' + json_path = 'queries[*].[report_token, space_token, token, name, raw_query]' + field_names = ['dashboard_id', 'dashboard_group_id', 'query_id', 'query_name', 'query_text'] + max_record_size = 1000 + pagination_json_path = 'queries[*]' + query_names_query = ModePaginatedRestApiQuery(query_to_join=seed_query, url=url, params=params, + json_path=json_path, field_names=field_names, + skip_no_result=True, max_record_size=max_record_size, + pagination_json_path=pagination_json_path) + + return query_names_query diff --git a/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_usage_extractor.py b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_usage_extractor.py new file mode 100644 index 0000000000..27e0b558c9 --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_usage_extractor.py @@ -0,0 +1,73 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Any + +from pyhocon import ConfigTree + +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_utils import ModeDashboardUtils +from databuilder.rest_api.mode_analytics.mode_paginated_rest_api_query import ModePaginatedRestApiQuery +from databuilder.rest_api.query_merger import QueryMerger + +LOGGER = logging.getLogger(__name__) + + +class ModeDashboardUsageExtractor(Extractor): + """ + A Extractor that extracts Mode dashboard's accumulated view count + """ + + def init(self, conf: ConfigTree) -> None: + self._conf = conf + + restapi_query = self._build_restapi_query() + self._extractor = ModeDashboardUtils.create_mode_rest_api_extractor(restapi_query=restapi_query, + conf=self._conf) + + def extract(self) -> Any: + return self._extractor.extract() + + def get_scope(self) -> str: + return 'extractor.mode_dashboard_usage' + + def _build_restapi_query(self) -> ModePaginatedRestApiQuery: + """ + Build REST API Query. 
To get Mode Dashboard usage, it needs to call three discovery APIs ( + spaces API, reports API and report stats API). + :return: A RestApiQuery that provides Mode Dashboard metadata + """ + + seed_query = ModeDashboardUtils.get_seed_query(conf=self._conf) + params = ModeDashboardUtils.get_auth_params(conf=self._conf, discover_auth=True) + + # Reports + # https://mode.com/developer/discovery-api/analytics/reports/ + reports_url = 'https://app.mode.com/batch/{organization}/reports' + reports_json_path = 'reports[*].[token, space_token]' + reports_field_names = ['dashboard_id', 'dashboard_group_id'] + reports_max_record_size = 1000 + reports_pagination_json_path = 'reports[*]' + spaces_query = ModeDashboardUtils.get_spaces_query_api(conf=self._conf) + spaces_query_merger = QueryMerger(query_to_merge=spaces_query, merge_key='dashboard_group_id') + reports_query = ModePaginatedRestApiQuery(query_to_join=seed_query, url=reports_url, params=params, + json_path=reports_json_path, field_names=reports_field_names, + skip_no_result=True, max_record_size=reports_max_record_size, + pagination_json_path=reports_pagination_json_path, + query_merger=spaces_query_merger) + + # https://mode.com/developer/discovery-api/analytics/report-stats/ + stats_url = 'https://app.mode.com/batch/{organization}/report_stats' + stats_json_path = 'report_stats[*].[report_token, view_count]' + stats_field_names = ['dashboard_id', 'accumulated_view_count'] + stats_max_record_size = 1000 + stats_pagination_json_path = 'report_stats[*]' + reports_query_merger = QueryMerger(query_to_merge=reports_query, merge_key='dashboard_id') + report_stats_query = ModePaginatedRestApiQuery(query_to_join=seed_query, url=stats_url, params=params, + json_path=stats_json_path, field_names=stats_field_names, + skip_no_result=True, max_record_size=stats_max_record_size, + pagination_json_path=stats_pagination_json_path, + query_merger=reports_query_merger) + + return report_stats_query diff --git a/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_user_extractor.py b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_user_extractor.py new file mode 100644 index 0000000000..45fd4f52d6 --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_user_extractor.py @@ -0,0 +1,107 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Any, List + +from pyhocon import ConfigFactory, ConfigTree +from requests.auth import HTTPBasicAuth + +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_constants import ( + MODE_ACCESS_TOKEN, MODE_PASSWORD_TOKEN, ORGANIZATION, +) +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_utils import ModeDashboardUtils +from databuilder.rest_api.base_rest_api_query import RestApiQuerySeed +from databuilder.rest_api.rest_api_failure_handlers import HttpFailureSkipOnStatus +from databuilder.rest_api.rest_api_query import RestApiQuery +from databuilder.transformer.base_transformer import ChainedTransformer, Transformer +from databuilder.transformer.dict_to_model import MODEL_CLASS, DictToModel +from databuilder.transformer.remove_field_transformer import FIELD_NAMES, RemoveFieldTransformer + +LOGGER = logging.getLogger(__name__) + + +class ModeDashboardUserExtractor(Extractor): + """ + An Extractor that extracts all Mode Dashboard user and add mode_user_id attribute to User model. + """ + + def init(self, conf: ConfigTree) -> None: + self._conf = conf + + restapi_query = self._build_restapi_query() + self._extractor = ModeDashboardUtils.create_mode_rest_api_extractor( + restapi_query=restapi_query, + conf=self._conf + ) + + # Remove all unnecessary fields because User model accepts all attributes and push it to Neo4j. + transformers: List[Transformer] = [] + + remove_fields_transformer = RemoveFieldTransformer() + remove_fields_transformer.init( + conf=Scoped.get_scoped_conf(self._conf, remove_fields_transformer.get_scope()).with_fallback( + ConfigFactory.from_dict( + {FIELD_NAMES: ['organization', 'mode_user_resource_path', 'product']}))) + transformers.append(remove_fields_transformer) + + dict_to_model_transformer = DictToModel() + dict_to_model_transformer.init( + conf=Scoped.get_scoped_conf(self._conf, dict_to_model_transformer.get_scope()).with_fallback( + ConfigFactory.from_dict( + {MODEL_CLASS: 'databuilder.models.user.User'}))) + transformers.append(dict_to_model_transformer) + + self._transformer = ChainedTransformer(transformers=transformers) + + def extract(self) -> Any: + record = self._extractor.extract() + if not record: + return None + + return self._transformer.transform(record=record) + + def get_scope(self) -> str: + return 'extractor.mode_dashboard_owner' + + def _build_restapi_query(self) -> RestApiQuery: + """ + Build REST API Query. To get Mode Dashboard owner, it needs to call three APIs (spaces API, reports + API, and user API) joining together. 
+ :return: A RestApiQuery that provides Mode Dashboard owner + """ + + # Seed query record for next query api to join with + seed_record = [{ + 'organization': self._conf.get_string(ORGANIZATION), + 'is_active': None, + 'updated_at': None, + 'do_not_update_empty_attribute': True, + }] + seed_query = RestApiQuerySeed(seed_record=seed_record) + + # memberships + # https://mode.com/developer/api-reference/management/organization-memberships/#listMemberships + memberships_url_template = 'https://app.mode.com/api/{organization}/memberships' + params = {'auth': HTTPBasicAuth(self._conf.get_string(MODE_ACCESS_TOKEN), + self._conf.get_string(MODE_PASSWORD_TOKEN))} + + json_path = '(_embedded.memberships[*].member_username) | (_embedded.memberships[*]._links.user.href)' + field_names = ['mode_user_id', 'mode_user_resource_path'] + mode_user_ids_query = RestApiQuery(query_to_join=seed_query, url=memberships_url_template, params=params, + json_path=json_path, field_names=field_names, + skip_no_result=True, json_path_contains_or=True) + + # https://mode.com/developer/api-reference/management/users/ + user_url_template = 'https://app.mode.com{mode_user_resource_path}' + + json_path = 'email' + field_names = ['email'] + failure_handler = HttpFailureSkipOnStatus(status_codes_to_skip={404}) + mode_user_email_query = RestApiQuery(query_to_join=mode_user_ids_query, url=user_url_template, + params=params, json_path=json_path, field_names=field_names, + skip_no_result=True, can_skip_failure=failure_handler.can_skip_failure) + + return mode_user_email_query diff --git a/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_utils.py b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_utils.py new file mode 100644 index 0000000000..aae4465eea --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/mode_analytics/mode_dashboard_utils.py @@ -0,0 +1,101 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict + +from pyhocon import ConfigFactory, ConfigTree +from requests.auth import HTTPBasicAuth + +from databuilder import Scoped +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_constants import ( + MODE_ACCESS_TOKEN, MODE_BEARER_TOKEN, MODE_PASSWORD_TOKEN, ORGANIZATION, +) +from databuilder.extractor.restapi.rest_api_extractor import ( + REST_API_QUERY, STATIC_RECORD_DICT, RestAPIExtractor, +) +from databuilder.rest_api.base_rest_api_query import BaseRestApiQuery, RestApiQuerySeed +from databuilder.rest_api.mode_analytics.mode_paginated_rest_api_query import ModePaginatedRestApiQuery + + +class ModeDashboardUtils(object): + + @staticmethod + def get_seed_query(conf: ConfigTree) -> BaseRestApiQuery: + # Seed query record for next query api to join with + seed_record = [{'organization': conf.get_string(ORGANIZATION)}] + seed_query = RestApiQuerySeed(seed_record=seed_record) + return seed_query + + @staticmethod + def get_spaces_query_api(conf: ConfigTree) -> BaseRestApiQuery: + """ + Provides RestApiQuerySeed where it will provides iterator of dictionaries as records where dictionary keys are + organization, dashboard_group_id, dashboard_group and dashboard_group_description + :param conf: + :return: + """ + + # https://mode.com/developer/discovery-api/analytics/spaces + spaces_url_template = 'https://app.mode.com/batch/{organization}/spaces' + + # Seed query record for next query api to join with + seed_query = ModeDashboardUtils.get_seed_query(conf=conf) + + # mode_bearer_token must be provided in the conf + # the token is required to access discovery endpoint + # https://mode.com/developer/discovery-api/introduction/ + params = ModeDashboardUtils.get_auth_params(conf=conf, discover_auth=True) + + json_path = 'spaces[*].[token,name,description]' + field_names = ['dashboard_group_id', 'dashboard_group', 'dashboard_group_description'] + + # based on https://mode.com/developer/discovery-api/analytics/spaces/#listSpacesForAccount + pagination_json_path = 'spaces[*]' + max_per_page = 1000 + spaces_query = ModePaginatedRestApiQuery(pagination_json_path=pagination_json_path, + max_record_size=max_per_page, query_to_join=seed_query, + url=spaces_url_template, params=params, json_path=json_path, + field_names=field_names) + + return spaces_query + + @staticmethod + def get_auth_params(conf: ConfigTree, discover_auth: bool = False) -> Dict[str, Any]: + if discover_auth: + # Mode discovery API needs custom token set in header + # https://mode.com/developer/discovery-api/introduction/ + params = { + "headers": { + "Authorization": conf.get_string(MODE_BEARER_TOKEN), + } + } # type: Dict[str, Any] + else: + params = { + 'auth': HTTPBasicAuth(conf.get_string(MODE_ACCESS_TOKEN), + conf.get_string(MODE_PASSWORD_TOKEN) + ) + } + return params + + @staticmethod + def create_mode_rest_api_extractor(restapi_query: BaseRestApiQuery, + conf: ConfigTree + ) -> RestAPIExtractor: + """ + Creates RestAPIExtractor. Note that RestAPIExtractor is already initialized + :param restapi_query: + :param conf: + :return: RestAPIExtractor. 
Note that RestAPIExtractor is already initialized + """ + extractor = RestAPIExtractor() + rest_api_extractor_conf = \ + Scoped.get_scoped_conf(conf, extractor.get_scope())\ + .with_fallback(conf)\ + .with_fallback(ConfigFactory.from_dict({REST_API_QUERY: restapi_query, + STATIC_RECORD_DICT: {'product': 'mode'} + } + ) + ) + + extractor.init(conf=rest_api_extractor_conf) + return extractor diff --git a/databuilder/databuilder/extractor/dashboard/redash/__init__.py b/databuilder/databuilder/extractor/dashboard/redash/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/redash/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/extractor/dashboard/redash/redash_dashboard_extractor.py b/databuilder/databuilder/extractor/dashboard/redash/redash_dashboard_extractor.py new file mode 100644 index 0000000000..24a7242537 --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/redash/redash_dashboard_extractor.py @@ -0,0 +1,267 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import importlib +from typing import ( + Any, Dict, Iterator, Optional, +) + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.dashboard.redash.redash_dashboard_utils import ( + RedashPaginatedRestApiQuery, generate_dashboard_description, get_auth_headers, get_text_widgets, + get_visualization_widgets, sort_widgets, +) +from databuilder.extractor.restapi.rest_api_extractor import REST_API_QUERY, RestAPIExtractor +from databuilder.models.dashboard.dashboard_chart import DashboardChart +from databuilder.models.dashboard.dashboard_last_modified import DashboardLastModifiedTimestamp +from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata +from databuilder.models.dashboard.dashboard_owner import DashboardOwner +from databuilder.models.dashboard.dashboard_query import DashboardQuery +from databuilder.models.dashboard.dashboard_table import DashboardTable +from databuilder.models.table_metadata import TableMetadata +from databuilder.rest_api.base_rest_api_query import EmptyRestApiQuerySeed +from databuilder.rest_api.rest_api_query import RestApiQuery +from databuilder.transformer.base_transformer import ChainedTransformer +from databuilder.transformer.timestamp_string_to_epoch import FIELD_NAME as TS_FIELD_NAME, TimestampStringToEpoch + + +class TableRelationData: + """ + This is sort of like a stripped down version of `TableMetadata`. + It is used as the type returned by the (optional) table parser. + """ + + def __init__(self, + database: str, + cluster: str, + schema: str, + name: str) -> None: + self._data = {'db': database, 'cluster': cluster, 'schema': schema, 'tbl': name} + + @property + def key(self) -> str: + return TableMetadata.TABLE_KEY_FORMAT.format(**self._data) + + +class RedashDashboardExtractor(Extractor): + """ + An extractor for retrieving dashboards and associated queries + (and possibly tables) from Redash. 
+ + There are five configuration values: + + - `redash_base_url`: (e.g., `https://redash.example.com`) Base URL for the user-facing + Redash application + - `api_base_url`: (e.g., `https://redash.example.com/api`) Base URL for the API + - `api_key`: Redash API key + - (optional) `cluster`: A cluster name for this Redash instance (defaults to `prod`) + - (optional) `table_parser`: A function `(RedashVisualizationWidget) -> List[TableRelationData]`. + Given a `RedashVisualizationWidget`, this should return a list of potentially related tables + in Amundsen. Any table returned that exists in Amundsen will be linked to the dashboard. + Any table that does not exist will be ignored. + """ + + REDASH_BASE_URL_KEY = 'redash_base_url' + API_BASE_URL_KEY = 'api_base_url' + API_KEY_KEY = 'api_key' + CLUSTER_KEY = 'cluster' # optional config + TABLE_PARSER_KEY = 'table_parser' # optional config + REDASH_VERSION = 'redash_version' # optional config + + DEFAULT_CLUSTER = 'prod' + DEFAULT_VERSION = 9 + + PRODUCT = 'redash' + DASHBOARD_GROUP_ID = 'redash' + DASHBOARD_GROUP_NAME = 'Redash' + + def init(self, conf: ConfigTree) -> None: + + # required configuration + self._redash_base_url = conf.get_string(RedashDashboardExtractor.REDASH_BASE_URL_KEY) + self._api_base_url = conf.get_string(RedashDashboardExtractor.API_BASE_URL_KEY) + self._api_key = conf.get_string(RedashDashboardExtractor.API_KEY_KEY) + + # optional configuration + self._cluster = conf.get_string( + RedashDashboardExtractor.CLUSTER_KEY, RedashDashboardExtractor.DEFAULT_CLUSTER + ) + self._redash_version = conf.get_int( + RedashDashboardExtractor.REDASH_VERSION, RedashDashboardExtractor.DEFAULT_VERSION + ) + + self._parse_tables = None + tbl_parser_path = conf.get_string(RedashDashboardExtractor.TABLE_PARSER_KEY) + if tbl_parser_path: + module_name, fn_name = tbl_parser_path.rsplit('.', 1) + mod = importlib.import_module(module_name) + self._parse_tables = getattr(mod, fn_name) + + self._extractor = self._build_extractor() + self._transformer = self._build_transformer() + self._extract_iter: Optional[Iterator[Any]] = None + + def _is_published_dashboard(self, record: Dict[str, Any]) -> bool: + return not (record['is_archived'] or record['is_draft']) + + def _get_extract_iter(self) -> Iterator[Any]: + + while True: + record = self._extractor.extract() + if not record: + break # the end. 
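# The chained transformer yields its output, so next(..., None) below pulls the single
# transformed dashboard record, or None if nothing was produced for this input.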
+ + record = next(self._transformer.transform(record=record), None) + + if not self._is_published_dashboard(record): + continue # filter this one out + + identity_data = { + 'cluster': self._cluster, + 'product': RedashDashboardExtractor.PRODUCT, + 'dashboard_group_id': str(RedashDashboardExtractor.DASHBOARD_GROUP_ID), + 'dashboard_id': str(record['dashboard_id']), + } + + if self._redash_version >= 9: + dashboard_url = f'{self._redash_base_url}/dashboards/{record["dashboard_id"]}' + else: + dashboard_url = f'{self._redash_base_url}/dashboard/{record["slug"]}' + + dash_data = { + 'dashboard_group': + RedashDashboardExtractor.DASHBOARD_GROUP_NAME, + 'dashboard_group_url': + self._redash_base_url, + 'dashboard_name': + record['dashboard_name'], + 'dashboard_url': + dashboard_url, + 'created_timestamp': + record['created_timestamp'] + } + dash_data.update(identity_data) + + widgets = sort_widgets(record['widgets']) + text_widgets = get_text_widgets(widgets) + viz_widgets = get_visualization_widgets(widgets) + + # generate a description for this dashboard, since Redash does not have descriptions + dash_data['description'] = generate_dashboard_description(text_widgets, viz_widgets) + + yield DashboardMetadata(**dash_data) + + last_mod_data = {'last_modified_timestamp': record['last_modified_timestamp']} + last_mod_data.update(identity_data) + + yield DashboardLastModifiedTimestamp(**last_mod_data) + + owner_data = {'email': record['user']['email']} + owner_data.update(identity_data) + + yield DashboardOwner(**owner_data) + + table_keys = set() + + for viz in viz_widgets: + query_data = { + 'query_id': str(viz.query_id), + 'query_name': viz.query_name, + 'url': self._redash_base_url + viz.query_relative_url, + 'query_text': viz.raw_query + } + + query_data.update(identity_data) + yield DashboardQuery(**query_data) + + chart_data = { + 'query_id': str(viz.query_id), + 'chart_id': str(viz.visualization_id), + 'chart_name': viz.visualization_name, + 'chart_type': viz.visualization_type, + } + chart_data.update(identity_data) + yield DashboardChart(**chart_data) + + # if a table parser is provided, retrieve tables from this viz + if self._parse_tables: + for tbl in self._parse_tables(viz): + table_keys.add(tbl.key) + + if len(table_keys) > 0: + yield DashboardTable(table_ids=list(table_keys), **identity_data) + + def extract(self) -> Any: + + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def _build_restapi_query(self) -> RestApiQuery: + + dashes_query = RedashPaginatedRestApiQuery( + query_to_join=EmptyRestApiQuerySeed(), + url=f'{self._api_base_url}/dashboards', + params=self._get_default_api_query_params(), + json_path='results[*].[id,name,slug,created_at,updated_at,is_archived,is_draft,user]', + field_names=[ + 'dashboard_id', 'dashboard_name', 'slug', 'created_timestamp', + 'last_modified_timestamp', 'is_archived', 'is_draft', 'user' + ], + skip_no_result=True + ) + + if self._redash_version >= 9: + dashboard_url = f'{self._api_base_url}/dashboards/{{dashboard_id}}' + else: + dashboard_url = f'{self._api_base_url}/dashboards/{{slug}}' + + return RestApiQuery( + query_to_join=dashes_query, + url=dashboard_url, + params=self._get_default_api_query_params(), + json_path='widgets', + field_names=['widgets'], + skip_no_result=True + ) + + def _get_default_api_query_params(self) -> Dict[str, Any]: + + return {'headers': get_auth_headers(self._api_key)} + + def _build_extractor(self) -> 
RestAPIExtractor: + + extractor = RestAPIExtractor() + rest_api_extractor_conf = ConfigFactory.from_dict({ + REST_API_QUERY: self._build_restapi_query() + }) + extractor.init(rest_api_extractor_conf) + return extractor + + def _build_transformer(self) -> ChainedTransformer: + + transformers = [] + + # transform timestamps from ISO to unix epoch + ts_transformer_1 = TimestampStringToEpoch() + ts_transformer_1.init(ConfigFactory.from_dict({ + TS_FIELD_NAME: 'created_timestamp', + })) + transformers.append(ts_transformer_1) + + ts_transformer_2 = TimestampStringToEpoch() + ts_transformer_2.init(ConfigFactory.from_dict({ + TS_FIELD_NAME: 'last_modified_timestamp', + })) + transformers.append(ts_transformer_2) + + return ChainedTransformer(transformers=transformers) + + def get_scope(self) -> str: + + return 'extractor.redash_dashboard' diff --git a/databuilder/databuilder/extractor/dashboard/redash/redash_dashboard_utils.py b/databuilder/databuilder/extractor/dashboard/redash/redash_dashboard_utils.py new file mode 100644 index 0000000000..11a5f480bf --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/redash/redash_dashboard_utils.py @@ -0,0 +1,155 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + Any, Dict, Iterable, List, Tuple, +) + +from databuilder.rest_api.rest_api_query import RestApiQuery + + +class RedashVisualizationWidget: + """ + A visualization widget in a Redash dashboard. + These are mapped 1:1 with queries, and can be of various types, e.g.: + CHART, TABLE, PIVOT, etc. + The query name acts like a title for the widget on the dashboard. + """ + + def __init__(self, data: Dict[str, Any]) -> None: + self._data = data + + @property + def raw_query(self) -> str: + return self._data['visualization']['query']['query'] + + @property + def data_source_id(self) -> int: + return self._data['visualization']['query']['data_source_id'] + + @property + def query_id(self) -> int: + return self._data['visualization']['query']['id'] + + @property + def query_relative_url(self) -> str: + return f'/queries/{self.query_id}' + + @property + def query_name(self) -> str: + return self._data['visualization']['query']['name'] + + @property + def visualization_id(self) -> int: + return self._data['visualization']['id'] + + @property + def visualization_name(self) -> str: + return self._data['visualization']['name'] + + @property + def visualization_type(self) -> str: + return self._data['visualization']['type'] + + +class RedashTextWidget: + """ + A textbox in a Redash dashboad. + It pretty much just contains a single text property (Markdown). 
+ """ + + def __init__(self, data: Dict[str, Any]) -> None: + self._data = data + + @property + def text(self) -> str: + return self._data['text'] + + +class RedashPaginatedRestApiQuery(RestApiQuery): + """ + Paginated Redash API queries + """ + + def __init__(self, **kwargs: Any) -> None: + super(RedashPaginatedRestApiQuery, self).__init__(**kwargs) + if 'params' not in self._params: + self._params['params'] = {} + self._params['params']['page'] = 1 + + def _total_records(self, res: Dict[str, Any]) -> int: + return res['count'] + + def _max_record_on_page(self, res: Dict[str, Any]) -> int: + return res['page_size'] * res['page'] + + def _next_page(self, res: Dict[str, Any]) -> int: + return res['page'] + 1 + + def _post_process(self, response: Any) -> None: + parsed = response.json() + + if self._max_record_on_page(parsed) >= self._total_records(parsed): + self._more_pages = False + else: + self._params['params']['page'] = self._next_page(parsed) + self._more_pages = True + + +def sort_widgets(widgets: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Sort raw widget data (as returned from the API) according to the position + of the widgets in the dashboard (top to bottom, left to right) + Redash does not return widgets in order of their position, + so we do this to ensure that we look at widgets in a sensible order. + """ + + def row_and_col(widget: Dict[str, Any]) -> Tuple[Any, Any]: + # these entities usually but not always have explicit rows and cols + pos = widget['options'].get('position', {}) + return (pos.get('row', 0), pos.get('col', 0)) + + return sorted(widgets, key=row_and_col) + + +def get_text_widgets(widgets: Iterable[Dict[str, Any]]) -> List[RedashTextWidget]: + """ + From the raw set of widget data returned from the API, filter down + to text widgets and return them as a list of `RedashTextWidget` + """ + + return [RedashTextWidget(widget) for widget in widgets + if 'text' in widget and 'visualization' not in widget] + + +def get_visualization_widgets(widgets: Iterable[Dict[str, Any]]) -> List[RedashVisualizationWidget]: + """ + From the raw set of widget data returned from the API, filter down + to visualization widgets and return them as a list of `RedashVisualizationWidget` + """ + + return [RedashVisualizationWidget(widget) for widget in widgets + if 'visualization' in widget] + + +def get_auth_headers(api_key: str) -> Dict[str, str]: + return {'Authorization': f'Key {api_key}'} + + +def generate_dashboard_description(text_widgets: List[RedashTextWidget], + viz_widgets: List[RedashVisualizationWidget]) -> str: + """ + Redash doesn't have dashboard descriptions, so we'll make our own. + If there exist any text widgets, concatenate them, + and use this text as the description for this dashboard. + If not, put together a list of query names. + If all else fails, this looks like an empty dashboard. + """ + + if len(text_widgets) > 0: + return '\n\n'.join([w.text for w in text_widgets]) + elif len(viz_widgets) > 0: + query_list = '\n'.join(set([f'- {v.query_name}' for v in set(viz_widgets)])) + return 'A dashboard containing the following queries:\n\n' + query_list + + return 'This dashboard appears to be empty!' diff --git a/databuilder/databuilder/extractor/dashboard/tableau/__init__.py b/databuilder/databuilder/extractor/dashboard/tableau/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/tableau/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/extractor/dashboard/tableau/tableau_dashboard_constants.py b/databuilder/databuilder/extractor/dashboard/tableau/tableau_dashboard_constants.py new file mode 100644 index 0000000000..fe6fe9d010 --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/tableau/tableau_dashboard_constants.py @@ -0,0 +1,16 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +API_VERSION = 'api_version' +API_BASE_URL = 'api_base_url' +TABLEAU_BASE_URL = 'tableau_base_url' +SITE_NAME = 'site_name' +TABLEAU_ACCESS_TOKEN_NAME = 'tableau_personal_access_token_name' +TABLEAU_ACCESS_TOKEN_SECRET = 'tableau_personal_access_token_secret' +EXCLUDED_PROJECTS = 'excluded_projects' +EXTERNAL_CLUSTER_NAME = 'external_cluster_name' +EXTERNAL_SCHEMA_NAME = 'external_schema_name' +EXTERNAL_TABLE_TYPES = 'external_table_types' +CLUSTER = 'cluster' +DATABASE = 'database' +VERIFY_REQUEST = 'verify_request' diff --git a/databuilder/databuilder/extractor/dashboard/tableau/tableau_dashboard_extractor.py b/databuilder/databuilder/extractor/dashboard/tableau/tableau_dashboard_extractor.py new file mode 100644 index 0000000000..44fbb5645e --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/tableau/tableau_dashboard_extractor.py @@ -0,0 +1,138 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import ( + Any, Dict, Iterator, List, +) + +from pyhocon import ConfigFactory, ConfigTree + +import databuilder.extractor.dashboard.tableau.tableau_dashboard_constants as const +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.dashboard.tableau.tableau_dashboard_utils import ( + TableauDashboardUtils, TableauGraphQLApiExtractor, +) +from databuilder.extractor.restapi.rest_api_extractor import STATIC_RECORD_DICT +from databuilder.transformer.base_transformer import ChainedTransformer, Transformer +from databuilder.transformer.dict_to_model import MODEL_CLASS, DictToModel +from databuilder.transformer.timestamp_string_to_epoch import FIELD_NAME, TimestampStringToEpoch + +LOGGER = logging.getLogger(__name__) + + +class TableauGraphQLApiMetadataExtractor(TableauGraphQLApiExtractor): + """ + Implements the extraction-time logic for parsing the GraphQL result and transforming into a dict + that fills the DashboardMetadata model. Allows workbooks to be exlcuded based on their project. 
+ """ + + CLUSTER = const.CLUSTER + EXCLUDED_PROJECTS = const.EXCLUDED_PROJECTS + SITE_NAME = const.SITE_NAME + TABLEAU_BASE_URL = const.TABLEAU_BASE_URL + + def execute(self) -> Iterator[Dict[str, Any]]: + response = self.execute_query() + + workbooks_data = [workbook for workbook in response['workbooks'] + if workbook['projectName'] not in + self._conf.get_list(TableauGraphQLApiMetadataExtractor.EXCLUDED_PROJECTS, [])] + base_url = self._conf.get(TableauGraphQLApiMetadataExtractor.TABLEAU_BASE_URL) + site_name = self._conf.get_string(TableauGraphQLApiMetadataExtractor.SITE_NAME, '') + site_url_path = '' + if site_name != '': + site_url_path = f'/site/{site_name}' + for workbook in workbooks_data: + if None in (workbook['projectName'], workbook['name']): + LOGGER.warning(f'Ignoring workbook (ID:{workbook["vizportalUrlId"]}) ' + + f'in project (ID:{workbook["projectVizportalUrlId"]}) because of a lack of permission') + continue + data = { + 'dashboard_group': workbook['projectName'], + 'dashboard_name': TableauDashboardUtils.sanitize_workbook_name(workbook['name']), + 'description': workbook.get('description', ''), + 'created_timestamp': workbook['createdAt'], + 'dashboard_group_url': f'{base_url}/#{site_url_path}/projects/{workbook["projectVizportalUrlId"]}', + 'dashboard_url': f'{base_url}/#{site_url_path}/workbooks/{workbook["vizportalUrlId"]}/views', + 'cluster': self._conf.get_string(TableauGraphQLApiMetadataExtractor.CLUSTER) + } + yield data + + +class TableauDashboardExtractor(Extractor): + """ + Extracts core metadata about Tableau "dashboards". + For the purposes of this extractor, Tableau "workbooks" are mapped to Amundsen dashboards, and the + top-level project in which these workbooks preside is the dashboard group. The metadata it gathers is: + Dashboard name (Workbook name) + Dashboard description (Workbook description) + Dashboard creation timestamp (Workbook creationstamp) + Dashboard group name (Workbook top-level folder name) + Uses the Metadata API: https://help.tableau.com/current/api/metadata_api/en-us/index.html + """ + + API_BASE_URL = const.API_BASE_URL + API_VERSION = const.API_VERSION + CLUSTER = const.CLUSTER + EXCLUDED_PROJECTS = const.EXCLUDED_PROJECTS + SITE_NAME = const.SITE_NAME + TABLEAU_BASE_URL = const.TABLEAU_BASE_URL + TABLEAU_ACCESS_TOKEN_NAME = const.TABLEAU_ACCESS_TOKEN_NAME + TABLEAU_ACCESS_TOKEN_SECRET = const.TABLEAU_ACCESS_TOKEN_SECRET + VERIFY_REQUEST = const.VERIFY_REQUEST + + def init(self, conf: ConfigTree) -> None: + self._conf = conf + self.query = """query { + workbooks { + id + name + createdAt + description + projectName + projectVizportalUrlId + vizportalUrlId + } + }""" + + self._extractor = self._build_extractor() + + transformers: List[Transformer] = [] + timestamp_str_to_epoch_transformer = TimestampStringToEpoch() + timestamp_str_to_epoch_transformer.init( + conf=Scoped.get_scoped_conf(self._conf, timestamp_str_to_epoch_transformer.get_scope()).with_fallback( + ConfigFactory.from_dict({FIELD_NAME: 'created_timestamp', }))) + transformers.append(timestamp_str_to_epoch_transformer) + + dict_to_model_transformer = DictToModel() + dict_to_model_transformer.init( + conf=Scoped.get_scoped_conf(self._conf, dict_to_model_transformer.get_scope()).with_fallback( + ConfigFactory.from_dict( + {MODEL_CLASS: 'databuilder.models.dashboard.dashboard_metadata.DashboardMetadata'}))) + transformers.append(dict_to_model_transformer) + self._transformer = ChainedTransformer(transformers=transformers) + + def extract(self) -> Any: + record = 
self._extractor.extract() + if not record: + return None + + return next(self._transformer.transform(record=record), None) + + def get_scope(self) -> str: + return 'extractor.tableau_dashboard_metadata' + + def _build_extractor(self) -> TableauGraphQLApiMetadataExtractor: + """ + Builds a TableauGraphQLApiMetadataExtractor. All data required can be retrieved with a single GraphQL call. + :return: A TableauGraphQLApiMetadataExtractor that provides core dashboard metadata. + """ + extractor = TableauGraphQLApiMetadataExtractor() + tableau_extractor_conf = Scoped.get_scoped_conf(self._conf, extractor.get_scope()) \ + .with_fallback(self._conf) \ + .with_fallback(ConfigFactory.from_dict({TableauGraphQLApiExtractor.QUERY: self.query, + STATIC_RECORD_DICT: {'product': 'tableau'}})) + extractor.init(conf=tableau_extractor_conf) + return extractor diff --git a/databuilder/databuilder/extractor/dashboard/tableau/tableau_dashboard_last_modified_extractor.py b/databuilder/databuilder/extractor/dashboard/tableau/tableau_dashboard_last_modified_extractor.py new file mode 100644 index 0000000000..fc1a0fa84e --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/tableau/tableau_dashboard_last_modified_extractor.py @@ -0,0 +1,129 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import ( + Any, Dict, Iterator, List, +) + +from pyhocon import ConfigFactory, ConfigTree + +import databuilder.extractor.dashboard.tableau.tableau_dashboard_constants as const +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.dashboard.tableau.tableau_dashboard_utils import ( + TableauDashboardUtils, TableauGraphQLApiExtractor, +) +from databuilder.extractor.restapi.rest_api_extractor import STATIC_RECORD_DICT +from databuilder.transformer.base_transformer import ChainedTransformer, Transformer +from databuilder.transformer.dict_to_model import MODEL_CLASS, DictToModel +from databuilder.transformer.timestamp_string_to_epoch import FIELD_NAME, TimestampStringToEpoch + +LOGGER = logging.getLogger(__name__) + + +class TableauGraphQLApiLastModifiedExtractor(TableauGraphQLApiExtractor): + """ + Implements the extraction-time logic for parsing the GraphQL result and transforming into a dict + that fills the DashboardLastModifiedTimestamp model. Allows workbooks to be exlcuded based on their project. + """ + + CLUSTER = const.CLUSTER + EXCLUDED_PROJECTS = const.EXCLUDED_PROJECTS + + def execute(self) -> Iterator[Dict[str, Any]]: + response = self.execute_query() + + workbooks_data = [workbook for workbook in response['workbooks'] + if workbook['projectName'] not in + self._conf.get_list(TableauGraphQLApiLastModifiedExtractor.EXCLUDED_PROJECTS, [])] + + for workbook in workbooks_data: + if None in (workbook['projectName'], workbook['name']): + LOGGER.warning(f'Ignoring workbook (ID:{workbook["vizportalUrlId"]}) ' + + f'in project (ID:{workbook["projectVizportalUrlId"]}) because of a lack of permission') + continue + data = { + 'dashboard_group_id': workbook['projectName'], + 'dashboard_id': TableauDashboardUtils.sanitize_workbook_name(workbook['name']), + 'last_modified_timestamp': workbook['updatedAt'], + 'cluster': self._conf.get_string(TableauGraphQLApiLastModifiedExtractor.CLUSTER) + } + yield data + + +class TableauDashboardLastModifiedExtractor(Extractor): + """ + Extracts metadata about the time of last update for Tableau dashboards. 
+ For the purposes of this extractor, Tableau "workbooks" are mapped to Amundsen dashboards, and the + top-level project in which these workbooks preside is the dashboard group. The metadata it gathers is: + Dashboard last modified timestamp (Workbook last modified timestamp) + """ + + API_BASE_URL = const.API_BASE_URL + API_VERSION = const.API_VERSION + CLUSTER = const.CLUSTER + EXCLUDED_PROJECTS = const.EXCLUDED_PROJECTS + SITE_NAME = const.SITE_NAME + TABLEAU_ACCESS_TOKEN_NAME = const.TABLEAU_ACCESS_TOKEN_NAME + TABLEAU_ACCESS_TOKEN_SECRET = const.TABLEAU_ACCESS_TOKEN_SECRET + VERIFY_REQUEST = const.VERIFY_REQUEST + + def init(self, conf: ConfigTree) -> None: + self._conf = conf + self.query = """query { + workbooks { + id + name + projectName + updatedAt + projectVizportalUrlId + vizportalUrlId + } + }""" + + self._extractor = self._build_extractor() + + transformers: List[Transformer] = [] + timestamp_str_to_epoch_transformer = TimestampStringToEpoch() + timestamp_str_to_epoch_transformer.init( + conf=Scoped.get_scoped_conf(self._conf, timestamp_str_to_epoch_transformer.get_scope()).with_fallback( + ConfigFactory.from_dict({FIELD_NAME: 'last_modified_timestamp', }))) + transformers.append(timestamp_str_to_epoch_transformer) + + dict_to_model_transformer = DictToModel() + dict_to_model_transformer.init( + conf=Scoped.get_scoped_conf(self._conf, dict_to_model_transformer.get_scope()).with_fallback( + ConfigFactory.from_dict( + {MODEL_CLASS: + 'databuilder.models.dashboard.dashboard_last_modified.DashboardLastModifiedTimestamp'}))) + transformers.append(dict_to_model_transformer) + + self._transformer = ChainedTransformer(transformers=transformers) + + def extract(self) -> Any: + record = self._extractor.extract() + if not record: + return None + + return next(self._transformer.transform(record=record), None) + + def get_scope(self) -> str: + return 'extractor.tableau_dashboard_last_modified' + + def _build_extractor(self) -> TableauGraphQLApiLastModifiedExtractor: + """ + Builds a TableauGraphQLApiExtractor. All data required can be retrieved with a single GraphQL call. + :return: A TableauGraphQLApiLastModifiedExtractor that provides dashboard update metadata. + """ + extractor = TableauGraphQLApiLastModifiedExtractor() + tableau_extractor_conf = \ + Scoped.get_scoped_conf(self._conf, extractor.get_scope())\ + .with_fallback(self._conf)\ + .with_fallback(ConfigFactory.from_dict({TableauGraphQLApiExtractor.QUERY: self.query, + STATIC_RECORD_DICT: {'product': 'tableau'} + } + ) + ) + extractor.init(conf=tableau_extractor_conf) + return extractor diff --git a/databuilder/databuilder/extractor/dashboard/tableau/tableau_dashboard_query_extractor.py b/databuilder/databuilder/extractor/dashboard/tableau/tableau_dashboard_query_extractor.py new file mode 100644 index 0000000000..b30530b9ca --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/tableau/tableau_dashboard_query_extractor.py @@ -0,0 +1,120 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import ( + Any, Dict, Iterator, +) + +from pyhocon import ConfigFactory, ConfigTree + +import databuilder.extractor.dashboard.tableau.tableau_dashboard_constants as const +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.dashboard.tableau.tableau_dashboard_utils import ( + TableauDashboardUtils, TableauGraphQLApiExtractor, +) +from databuilder.extractor.restapi.rest_api_extractor import STATIC_RECORD_DICT +from databuilder.transformer.base_transformer import ChainedTransformer +from databuilder.transformer.dict_to_model import MODEL_CLASS, DictToModel + +LOGGER = logging.getLogger(__name__) + + +class TableauGraphQLApiQueryExtractor(TableauGraphQLApiExtractor): + """ + Implements the extraction-time logic for parsing the GraphQL result and transforming into a dict + that fills the DashboardQuery model. Allows workbooks to be exlcuded based on their project. + """ + + CLUSTER = const.CLUSTER + EXCLUDED_PROJECTS = const.EXCLUDED_PROJECTS + + def execute(self) -> Iterator[Dict[str, Any]]: + response = self.execute_query() + + for query in response['customSQLTables']: + for workbook in query['downstreamWorkbooks']: + if workbook['projectName'] not in \ + self._conf.get_list(TableauGraphQLApiQueryExtractor.EXCLUDED_PROJECTS, []): + data = { + 'dashboard_group_id': workbook['projectName'], + 'dashboard_id': TableauDashboardUtils.sanitize_workbook_name(workbook['name']), + 'query_name': query['name'], + 'query_id': query['id'], + 'query_text': query['query'], + 'cluster': self._conf.get_string(TableauGraphQLApiQueryExtractor.CLUSTER) + } + yield data + + +class TableauDashboardQueryExtractor(Extractor): + """ + Extracts metadata about the queries associated with Tableau workbooks. + In terms of Tableau's Metadata API, these queries are called "custom SQL tables". + However, not every workbook uses custom SQL queries, and most are built with a mixture of using the + datasource fields directly and various "calculated" columns. + This extractor iterates through one query at a time, yielding a new relationship for every downstream + workbook that uses the query. 
+ """ + + API_BASE_URL = const.API_BASE_URL + API_VERSION = const.API_VERSION + CLUSTER = const.CLUSTER + EXCLUDED_PROJECTS = const.EXCLUDED_PROJECTS + SITE_NAME = const.SITE_NAME + TABLEAU_ACCESS_TOKEN_NAME = const.TABLEAU_ACCESS_TOKEN_NAME + TABLEAU_ACCESS_TOKEN_SECRET = const.TABLEAU_ACCESS_TOKEN_SECRET + VERIFY_REQUEST = const.VERIFY_REQUEST + + def init(self, conf: ConfigTree) -> None: + self._conf = conf + self.query = """query { + customSQLTables { + id + name + query + downstreamWorkbooks { + name + projectName + } + } + }""" + + self._extractor = self._build_extractor() + + transformers = [] + dict_to_model_transformer = DictToModel() + dict_to_model_transformer.init( + conf=Scoped.get_scoped_conf(self._conf, dict_to_model_transformer.get_scope()).with_fallback( + ConfigFactory.from_dict( + {MODEL_CLASS: 'databuilder.models.dashboard.dashboard_query.DashboardQuery'}))) + transformers.append(dict_to_model_transformer) + self._transformer = ChainedTransformer(transformers=transformers) + + def extract(self) -> Any: + record = self._extractor.extract() + if not record: + return None + + return next(self._transformer.transform(record=record), None) + + def get_scope(self) -> str: + return 'extractor.tableau_dashboard_query' + + def _build_extractor(self) -> TableauGraphQLApiQueryExtractor: + """ + Builds a TableauGraphQLApiQueryExtractor. All data required can be retrieved with a single GraphQL call. + :return: A TableauGraphQLApiQueryExtractor that provides dashboard query metadata. + """ + extractor = TableauGraphQLApiQueryExtractor() + tableau_extractor_conf = \ + Scoped.get_scoped_conf(self._conf, extractor.get_scope())\ + .with_fallback(self._conf)\ + .with_fallback(ConfigFactory.from_dict({TableauGraphQLApiExtractor.QUERY: self.query, + STATIC_RECORD_DICT: {'product': 'tableau'} + } + ) + ) + extractor.init(conf=tableau_extractor_conf) + return extractor diff --git a/databuilder/databuilder/extractor/dashboard/tableau/tableau_dashboard_table_extractor.py b/databuilder/databuilder/extractor/dashboard/tableau/tableau_dashboard_table_extractor.py new file mode 100644 index 0000000000..2f4a358665 --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/tableau/tableau_dashboard_table_extractor.py @@ -0,0 +1,174 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import ( + Any, Dict, Iterator, +) + +from pyhocon import ConfigFactory, ConfigTree + +import databuilder.extractor.dashboard.tableau.tableau_dashboard_constants as const +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.dashboard.tableau.tableau_dashboard_utils import ( + TableauDashboardUtils, TableauGraphQLApiExtractor, +) +from databuilder.extractor.restapi.rest_api_extractor import STATIC_RECORD_DICT +from databuilder.models.table_metadata import TableMetadata +from databuilder.transformer.base_transformer import ChainedTransformer +from databuilder.transformer.dict_to_model import MODEL_CLASS, DictToModel + +LOGGER = logging.getLogger(__name__) + + +class TableauGraphQLDashboardTableExtractor(TableauGraphQLApiExtractor): + """ + Implements the extraction-time logic for parsing the GraphQL result and transforming into a dict + that fills the DashboardTable model. Allows workbooks to be exlcuded based on their project. 
+ """ + + CLUSTER = const.CLUSTER + DATABASE = const.DATABASE + EXCLUDED_PROJECTS = const.EXCLUDED_PROJECTS + EXTERNAL_CLUSTER_NAME = const.EXTERNAL_CLUSTER_NAME + + def execute(self) -> Iterator[Dict[str, Any]]: + response = self.execute_query() + + workbooks_data = [workbook for workbook in response['workbooks'] + if workbook['projectName'] not in + self._conf.get_list(TableauGraphQLDashboardTableExtractor.EXCLUDED_PROJECTS, [])] + + for workbook in workbooks_data: + if None in (workbook['projectName'], workbook['name']): + LOGGER.warning(f'Ignoring workbook (ID:{workbook["vizportalUrlId"]}) ' + + f'in project (ID:{workbook["projectVizportalUrlId"]}) because of a lack of permission') + continue + data = { + 'dashboard_group_id': workbook['projectName'], + 'dashboard_id': TableauDashboardUtils.sanitize_workbook_name(workbook['name']), + 'cluster': self._conf.get_string(TableauGraphQLDashboardTableExtractor.CLUSTER), + 'table_ids': [] + } + + for table in workbook['upstreamTables']: + if table['name'] is None: + LOGGER.warning(f'Ignoring a table in workbook (ID:{workbook["name"]}) ' + + f'in project (ID:{workbook["projectName"]}) because of a lack of permission') + continue + # external tables have no schema, so they must be parsed differently + # see TableauExternalTableExtractor for more specifics + if table['schema'] != '': + cluster = self._conf.get_string(TableauGraphQLDashboardTableExtractor.CLUSTER) + database = self._conf.get_string(TableauGraphQLDashboardTableExtractor.DATABASE) + + # Tableau sometimes incorrectly assigns the "schema" value + # based on how the datasource connection is used in a workbook. + # It will hide the real schema inside the table name, like "real_schema.real_table", + # and set the "schema" value to "wrong_schema". In every case discovered so far, the schema + # key is incorrect, so the "inner" schema from the table name is used instead. + if '.' in table['name']: + parts = table['name'].split('.') + if len(parts) == 2: + schema, name = parts + else: + database = '.'.join(parts[:-2]) + schema, name = parts[-2:] + else: + schema, name = table['schema'], table['name'] + schema = TableauDashboardUtils.sanitize_schema_name(schema) + name = TableauDashboardUtils.sanitize_table_name(name) + else: + cluster = self._conf.get_string(TableauGraphQLDashboardTableExtractor.EXTERNAL_CLUSTER_NAME) + database = TableauDashboardUtils.sanitize_database_name( + table['database']['connectionType'] + ) + schema = TableauDashboardUtils.sanitize_schema_name(table['database']['name']) + name = TableauDashboardUtils.sanitize_table_name(table['name']) + + table_id = TableMetadata.TABLE_KEY_FORMAT.format( + db=database, + cluster=cluster, + schema=schema, + tbl=name, + ) + data['table_ids'].append(table_id) + + yield data + + +class TableauDashboardTableExtractor(Extractor): + """ + Extracts metadata about the tables associated with Tableau workbooks. + It can handle both "regular" database tables as well as "external" tables + (see TableauExternalTableExtractor for more info on external tables). + Assumes that all the nodes for both the dashboards and the tables have already been created. 
+ """ + + API_BASE_URL = const.API_BASE_URL + API_VERSION = const.API_VERSION + CLUSTER = const.CLUSTER + DATABASE = const.DATABASE + EXCLUDED_PROJECTS = const.EXCLUDED_PROJECTS + EXTERNAL_CLUSTER_NAME = const.EXTERNAL_CLUSTER_NAME + SITE_NAME = const.SITE_NAME + TABLEAU_ACCESS_TOKEN_NAME = const.TABLEAU_ACCESS_TOKEN_NAME + TABLEAU_ACCESS_TOKEN_SECRET = const.TABLEAU_ACCESS_TOKEN_SECRET + VERIFY_REQUEST = const.VERIFY_REQUEST + + def init(self, conf: ConfigTree) -> None: + self._conf = conf + self.query = """query { + workbooks { + name + projectName + projectVizportalUrlId + vizportalUrlId + upstreamTables { + name + schema + database { + name + connectionType + } + } + } + }""" + self._extractor = self._build_extractor() + + transformers = [] + dict_to_model_transformer = DictToModel() + dict_to_model_transformer.init( + conf=Scoped.get_scoped_conf(self._conf, dict_to_model_transformer.get_scope()).with_fallback( + ConfigFactory.from_dict( + {MODEL_CLASS: 'databuilder.models.dashboard.dashboard_table.DashboardTable'}))) + transformers.append(dict_to_model_transformer) + self._transformer = ChainedTransformer(transformers=transformers) + + def extract(self) -> Any: + record = self._extractor.extract() + if not record: + return None + + return next(self._transformer.transform(record=record), None) + + def get_scope(self) -> str: + return 'extractor.tableau_dashboard_table' + + def _build_extractor(self) -> TableauGraphQLDashboardTableExtractor: + """ + Builds a TableauGraphQLDashboardTableExtractor. All data required can be retrieved with a single GraphQL call. + :return: A TableauGraphQLDashboardTableExtractor that creates dashboard <> table relationships. + """ + extractor = TableauGraphQLDashboardTableExtractor() + tableau_extractor_conf = \ + Scoped.get_scoped_conf(self._conf, extractor.get_scope())\ + .with_fallback(self._conf)\ + .with_fallback(ConfigFactory.from_dict({TableauGraphQLApiExtractor.QUERY: self.query, + STATIC_RECORD_DICT: {'product': 'tableau'} + } + ) + ) + extractor.init(conf=tableau_extractor_conf) + return extractor diff --git a/databuilder/databuilder/extractor/dashboard/tableau/tableau_dashboard_utils.py b/databuilder/databuilder/extractor/dashboard/tableau/tableau_dashboard_utils.py new file mode 100644 index 0000000000..b46f9b490e --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/tableau/tableau_dashboard_utils.py @@ -0,0 +1,199 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import abc +import json +import re +from typing import ( + Any, Dict, Iterator, Optional, +) + +import requests +from pyhocon import ConfigTree + +import databuilder.extractor.dashboard.tableau.tableau_dashboard_constants as const +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.restapi.rest_api_extractor import STATIC_RECORD_DICT + + +class TableauDashboardUtils: + """ + Provides various utility functions specifc to the Tableau dashboard extractors. + """ + + @staticmethod + def sanitize_schema_name(schema_name: str) -> str: + """ + Sanitizes a given string so that it can safely be used as a table's schema. + Sanitization behaves as follows: + - all spaces and periods are replaced by underscores + - any [], (), -, &, and ? 
characters are deleted + """ + # this indentation looks a little odd, but otherwise the linter complains + return re.sub(r' ', '_', + re.sub(r'\.', '_', + re.sub(r'(\[|\]|\(|\)|\-|\&|\?)', '', schema_name))) + + @staticmethod + def sanitize_database_name(database_name: str) -> str: + """ + Sanitizes a given string so that it can safely be used as a table's database. + Sanitization behaves as follows: + - all hyphens are deleted + """ + return re.sub(r'-', '', database_name) + + @staticmethod + def sanitize_table_name(table_name: str) -> str: + """ + Sanitizes a given string so that it can safely be used as a table name. + Replicates the current behavior of sanitize_workbook_name, but this is purely coincidental. + As more breaking characters/patterns are found, each method should be updated to reflect the specifics. + Sanitization behaves as follows: + - all forward slashes and single quotes characters are deleted + """ + return re.sub(r'(\/|\')', '', table_name) + + @staticmethod + def sanitize_workbook_name(workbook_name: str) -> str: + """ + Sanitizes a given string so that it can safely be used as a workbook ID. + Mimics the current behavior of sanitize_table_name for now, but is purely coincidental. + As more breaking characters/patterns are found, each method should be updated to reflect the specifics. + Sanitization behaves as follows: + - all forward slashes and single quotes characters are deleted + """ + return re.sub(r'(\/|\')', '', workbook_name) + + +class TableauGraphQLApiExtractor(Extractor): + """ + Base class for querying the Tableau Metdata API, which uses a GraphQL schema. + """ + + API_BASE_URL = const.API_BASE_URL + QUERY = 'query' + QUERY_VARIABLES = 'query_variables' + VERIFY_REQUEST = 'verify_request' + + def init(self, conf: ConfigTree) -> None: + self._conf = conf + self._auth_token = TableauDashboardAuth(self._conf).token + self._query = self._conf.get(TableauGraphQLApiExtractor.QUERY) + self._iterator: Optional[Iterator[Dict[str, Any]]] = None + self._static_dict = conf.get(STATIC_RECORD_DICT, dict()) + self._metadata_url = '{api_base_url}/api/metadata/graphql'.format( + api_base_url=self._conf.get_string(TableauGraphQLApiExtractor.API_BASE_URL) + ) + self._query_variables = self._conf.get(TableauGraphQLApiExtractor.QUERY_VARIABLES, {}) + self._verify_request = self._conf.get(TableauGraphQLApiExtractor.VERIFY_REQUEST, None) + + def execute_query(self) -> Dict[str, Any]: + """ + Executes the extractor's given query and returns the data from the results. + """ + query_payload = json.dumps({ + 'query': self._query, + 'variables': self._query_variables + }) + headers = { + 'Content-Type': 'application/json', + 'X-Tableau-Auth': self._auth_token + } + params: Dict[str, Any] = { + 'headers': headers + } + if self._verify_request is not None: + params['verify'] = self._verify_request + + response = requests.post(url=self._metadata_url, data=query_payload, **params) + return response.json()['data'] + + @abc.abstractmethod + def execute(self) -> Iterator[Dict[str, Any]]: + """ + Must be overriden by any extractor using this class. This should parse the result and yield each entity's + metadata one by one. + """ + pass + + def extract(self) -> Any: + """ + Fetch one result at a time from the generator created by self.execute(), updating using the + static record values if needed. 
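+        The static record values come from the STATIC_RECORD_DICT config entry supplied by the
+        concrete extractor, e.g. {'product': 'tableau'}, and are merged into every yielded record.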
+ """ + if not self._iterator: + self._iterator = self.execute() + + try: + record = next(self._iterator) + except StopIteration: + return None + + if self._static_dict: + record.update(self._static_dict) + + return record + + +class TableauDashboardAuth: + """ + Attempts to authenticate agains the Tableau REST API using the provided personal access token credentials. + When successful, it will create a valid token that must be used on all subsequent requests. + https://help.tableau.com/current/api/rest_api/en-us/REST/rest_api_concepts_auth.htm + """ + + API_BASE_URL = const.API_BASE_URL + API_VERSION = const.API_VERSION + SITE_NAME = const.SITE_NAME + TABLEAU_ACCESS_TOKEN_NAME = const.TABLEAU_ACCESS_TOKEN_NAME + TABLEAU_ACCESS_TOKEN_SECRET = const.TABLEAU_ACCESS_TOKEN_SECRET + VERIFY_REQUEST = const.VERIFY_REQUEST + + def __init__(self, conf: ConfigTree) -> None: + self._token: Optional[str] = None + self._conf = conf + self._access_token_name = self._conf.get_string(TableauDashboardAuth.TABLEAU_ACCESS_TOKEN_NAME) + self._access_token_secret = self._conf.get_string(TableauDashboardAuth.TABLEAU_ACCESS_TOKEN_SECRET) + self._api_version = self._conf.get_string(TableauDashboardAuth.API_VERSION) + self._site_name = self._conf.get_string(TableauDashboardAuth.SITE_NAME, '') + self._api_base_url = self._conf.get_string(TableauDashboardAuth.API_BASE_URL) + self._verify_request = self._conf.get(TableauDashboardAuth.VERIFY_REQUEST, None) + + @property + def token(self) -> Optional[str]: + if not self._token: + self._token = self._authenticate() + return self._token + + def _authenticate(self) -> str: + """ + Queries the auth/signin endpoint for the given Tableau instance using a personal access token. + The API version differs with your version of Tableau. + See https://help.tableau.com/current/api/rest_api/en-us/REST/rest_api_concepts_versions.htm + for details or ask your Tableau server administrator. + """ + self._auth_url = f"{self._api_base_url}/api/{self._api_version}/auth/signin" + + payload = json.dumps({ + 'credentials': { + 'personalAccessTokenName': self._access_token_name, + 'personalAccessTokenSecret': self._access_token_secret, + 'site': { + 'contentUrl': self._site_name + } + } + }) + headers = { + 'Accept': 'application/json', + 'Content-Type': 'application/json' + } + # verify = False is needed bypass occasional (valid) self-signed cert errors. TODO: actually fix it!! + params: Dict[str, Any] = { + 'headers': headers + } + if self._verify_request is not None: + params['verify'] = self._verify_request + + response_json = requests.post(url=self._auth_url, data=payload, **params).json() + return response_json['credentials']['token'] diff --git a/databuilder/databuilder/extractor/dashboard/tableau/tableau_external_table_extractor.py b/databuilder/databuilder/extractor/dashboard/tableau/tableau_external_table_extractor.py new file mode 100644 index 0000000000..744c7eb97e --- /dev/null +++ b/databuilder/databuilder/extractor/dashboard/tableau/tableau_external_table_extractor.py @@ -0,0 +1,147 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import ( + Any, Dict, Iterator, +) + +from pyhocon import ConfigFactory, ConfigTree + +import databuilder.extractor.dashboard.tableau.tableau_dashboard_constants as const +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.dashboard.tableau.tableau_dashboard_utils import ( + TableauDashboardUtils, TableauGraphQLApiExtractor, +) +from databuilder.transformer.base_transformer import ChainedTransformer +from databuilder.transformer.dict_to_model import MODEL_CLASS, DictToModel + +LOGGER = logging.getLogger(__name__) + + +class TableauGraphQLExternalTableExtractor(TableauGraphQLApiExtractor): + """ + Implements the extraction-time logic for parsing the GraphQL result and transforming into a dict + that fills the TableMetadata model. + """ + + EXTERNAL_CLUSTER_NAME = const.EXTERNAL_CLUSTER_NAME + EXTERNAL_SCHEMA_NAME = const.EXTERNAL_SCHEMA_NAME + + def execute(self) -> Iterator[Dict[str, Any]]: + response = self.execute_query() + + for table in response['databases']: + if table['connectionType'] in ['google-sheets', 'salesforce', 'excel-direct']: + for downstreamTable in table['tables']: + data = { + 'cluster': self._conf.get_string(TableauGraphQLExternalTableExtractor.EXTERNAL_CLUSTER_NAME), + 'database': TableauDashboardUtils.sanitize_database_name( + table['connectionType'] + ), + 'schema': TableauDashboardUtils.sanitize_schema_name(table['name']), + 'name': TableauDashboardUtils.sanitize_table_name(downstreamTable['name']), + 'description': table['description'] + } + yield data + else: + data = { + 'cluster': self._conf.get_string(TableauGraphQLExternalTableExtractor.EXTERNAL_CLUSTER_NAME), + 'database': TableauDashboardUtils.sanitize_database_name(table['connectionType']), + 'schema': self._conf.get_string(TableauGraphQLExternalTableExtractor.EXTERNAL_SCHEMA_NAME), + 'name': TableauDashboardUtils.sanitize_table_name(table['name']), + 'description': table['description'] + } + yield data + + +class TableauDashboardExternalTableExtractor(Extractor): + """ + Creates the "external" Tableau tables. + In this context, "external" tables are "tables" that are not from a typical database, and are loaded + using some other data format, like CSV files. + This extractor has been tested with the following types of external tables: + Excel spreadsheets + Text files (including CSV files) + Salesforce connections + Google Sheets connections + + Excel spreadsheets, Salesforce connections, and Google Sheets connections are all classified as + "databases" in terms of Tableau's Metadata API, with their "subsheets" forming their "tables" when + present. However, these tables are not assigned a schema, this extractor chooses to use the name + parent sheet as the schema, and assign a new table to each subsheet. The connection type is + always used as the database, and for text files, the schema is set using the EXTERNAL_SCHEMA_NAME + config option. Since these external tables are usually named for human consumption only and often + contain a wider range of characters, all inputs are transformed to remove any problematic + occurences before they are inserted: see the sanitize methods TableauDashboardUtils for specifics. 
+ + A more concrete example: if one had a Google Sheet titled "Growth by Region & County" with 2 subsheets called + "FY19 Report" and "FY20 Report", two tables would be generated with the following keys: + googlesheets://external.growth_by_region_county/FY_19_Report + googlesheets://external.growth_by_region_county/FY_20_Report + """ + + API_BASE_URL = const.API_BASE_URL + API_VERSION = const.API_VERSION + CLUSTER = const.CLUSTER + EXCLUDED_PROJECTS = const.EXCLUDED_PROJECTS + EXTERNAL_CLUSTER_NAME = const.EXTERNAL_CLUSTER_NAME + EXTERNAL_SCHEMA_NAME = const.EXTERNAL_SCHEMA_NAME + EXTERNAL_TABLE_TYPES = const.EXTERNAL_TABLE_TYPES + SITE_NAME = const.SITE_NAME + TABLEAU_ACCESS_TOKEN_NAME = const.TABLEAU_ACCESS_TOKEN_NAME + TABLEAU_ACCESS_TOKEN_SECRET = const.TABLEAU_ACCESS_TOKEN_SECRET + VERIFY_REQUEST = const.VERIFY_REQUEST + + def init(self, conf: ConfigTree) -> None: + self._conf = conf + self.query = """query externalTables($externalTableTypes: [String]) { + databases (filter: {connectionTypeWithin: $externalTableTypes}) { + name + connectionType + description + tables { + name + } + } + }""" + self.query_variables = { + 'externalTableTypes': self._conf.get_list(TableauDashboardExternalTableExtractor.EXTERNAL_TABLE_TYPES)} + self._extractor = self._build_extractor() + + transformers = [] + dict_to_model_transformer = DictToModel() + dict_to_model_transformer.init( + conf=Scoped.get_scoped_conf(self._conf, dict_to_model_transformer.get_scope()).with_fallback( + ConfigFactory.from_dict( + {MODEL_CLASS: 'databuilder.models.table_metadata.TableMetadata'}))) + transformers.append(dict_to_model_transformer) + self._transformer = ChainedTransformer(transformers=transformers) + + def extract(self) -> Any: + record = self._extractor.extract() + if not record: + return None + + return self._transformer.transform(record=record) + + def get_scope(self) -> str: + return 'extractor.tableau_external_table' + + def _build_extractor(self) -> TableauGraphQLExternalTableExtractor: + """ + Builds a TableauGraphQLExternalTableExtractor. All data required can be retrieved with a single GraphQL call. + :return: A TableauGraphQLExternalTableExtractor that creates external table metadata entities. + """ + extractor = TableauGraphQLExternalTableExtractor() + + config_dict = { + TableauGraphQLApiExtractor.QUERY_VARIABLES: self.query_variables, + TableauGraphQLApiExtractor.QUERY: self.query} + tableau_extractor_conf = \ + Scoped.get_scoped_conf(self._conf, extractor.get_scope())\ + .with_fallback(self._conf)\ + .with_fallback(ConfigFactory.from_dict(config_dict)) + extractor.init(conf=tableau_extractor_conf) + return extractor diff --git a/databuilder/databuilder/extractor/db2_metadata_extractor.py b/databuilder/databuilder/extractor/db2_metadata_extractor.py new file mode 100644 index 0000000000..6ad8149fbc --- /dev/null +++ b/databuilder/databuilder/extractor/db2_metadata_extractor.py @@ -0,0 +1,130 @@ +# Copyright Contributors to the Amundsen project. 
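+#
+# Minimal wiring sketch for the Db2MetadataExtractor defined below; the connection string, cluster
+# name and filter shown are hypothetical, and the SQLAlchemy connection key is nested under this
+# extractor's scope, as is usual for databuilder:
+#
+#     from pyhocon import ConfigFactory
+#     from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor
+#
+#     job_config = ConfigFactory.from_dict({
+#         'extractor.db2_metadata.where_clause_suffix': "WHERE c.TABSCHEMA NOT LIKE 'SYS%'",
+#         'extractor.db2_metadata.cluster_key': 'prod',
+#         f'extractor.db2_metadata.extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}':
+#             'db2+ibm_db://user:password@db2host:50000/BLUDB',
+#     })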
+# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections import namedtuple +from itertools import groupby +from typing import ( + Any, Dict, Iterator, Union, +) + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +TableKey = namedtuple('TableKey', ['schema', 'table_name']) + +LOGGER = logging.getLogger(__name__) + + +class Db2MetadataExtractor(Extractor): + """ + Extracts Db2 table and column metadata from underlying meta store database using SQLAlchemyExtractor + """ + # SELECT statement from Db2 SYSIBM to extract table and column metadata + SQL_STATEMENT = """ + SELECT + {cluster_source} as cluster, c.TABSCHEMA as schema, c.TABNAME as name, t.REMARKS as description, + c.COLNAME as col_name, + CASE WHEN c.TYPENAME='VARCHAR' OR c.TYPENAME='CHARACTER' THEN + TRIM (TRAILING FROM c.TYPENAME) concat '(' concat c.LENGTH concat ')' + WHEN c.TYPENAME='DECIMAL' THEN + TRIM (TRAILING FROM c.TYPENAME) concat '(' concat c.LENGTH concat ',' concat c.SCALE concat ')' + ELSE TRIM (TRAILING FROM c.TYPENAME) END as col_type, + c.REMARKS as col_description, c.COLNO as col_sort_order + FROM SYSCAT.COLUMNS c + INNER JOIN + SYSCAT.TABLES as t on c.TABSCHEMA=t.TABSCHEMA and c.TABNAME=t.TABNAME + {where_clause_suffix} + ORDER by cluster, schema, name, col_sort_order ; + """ + + # CONFIG KEYS + WHERE_CLAUSE_SUFFIX_KEY = 'where_clause_suffix' + CLUSTER_KEY = 'cluster_key' + DATABASE_KEY = 'database_key' + + # Default values + DEFAULT_CLUSTER_NAME = 'master' + + DEFAULT_CONFIG = ConfigFactory.from_dict( + {WHERE_CLAUSE_SUFFIX_KEY: ' ', CLUSTER_KEY: DEFAULT_CLUSTER_NAME} + ) + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(Db2MetadataExtractor.DEFAULT_CONFIG) + self._cluster = conf.get_string(Db2MetadataExtractor.CLUSTER_KEY) + + cluster_source = f"'{self._cluster}'" + + self._database = conf.get_string(Db2MetadataExtractor.DATABASE_KEY, default='db2') + + self.sql_stmt = Db2MetadataExtractor.SQL_STATEMENT.format( + where_clause_suffix=conf.get_string(Db2MetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY), + cluster_source=cluster_source + ) + + self._alchemy_extractor = SQLAlchemyExtractor() + sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope())\ + .with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt})) + + self.sql_stmt = sql_alch_conf.get_string(SQLAlchemyExtractor.EXTRACT_SQL) + + LOGGER.info('SQL for Db2 metadata: %s', self.sql_stmt) + + self._alchemy_extractor.init(sql_alch_conf) + self._extract_iter: Union[None, Iterator] = None + + def extract(self) -> Union[TableMetadata, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.db2_metadata' + + def _get_extract_iter(self) -> Iterator[TableMetadata]: + """ + Using itertools.groupby and raw level iterator, it groups to table and yields TableMetadata + :return: + """ + for key, group in groupby(self._get_raw_extract_iter(), self._get_table_key): + columns = [] + + for row in group: + last_row = row + columns.append(ColumnMetadata(row['col_name'], row['col_description'], + row['col_type'], row['col_sort_order'])) + + yield TableMetadata(self._database, 
last_row['cluster'], + last_row['schema'], + last_row['name'], + last_row['description'], + columns) + + def _get_raw_extract_iter(self) -> Iterator[Dict[str, Any]]: + """ + Provides iterator of result row from SQLAlchemy extractor + :return: + """ + row = self._alchemy_extractor.extract() + while row: + yield row + row = self._alchemy_extractor.extract() + + def _get_table_key(self, row: Dict[str, Any]) -> Union[TableKey, None]: + """ + Table key consists of schema and table name + :param row: + :return: + """ + if row: + return TableKey(schema=row['schema'], table_name=row['name']) + + return None diff --git a/databuilder/databuilder/extractor/db_api_extractor.py b/databuilder/databuilder/extractor/db_api_extractor.py new file mode 100644 index 0000000000..33447277b1 --- /dev/null +++ b/databuilder/databuilder/extractor/db_api_extractor.py @@ -0,0 +1,83 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import importlib +import logging +from typing import Any, Iterable + +from pyhocon import ConfigTree + +from databuilder.extractor.base_extractor import Extractor + +LOGGER = logging.getLogger(__name__) + + +class DBAPIExtractor(Extractor): + """ + Generic DB API extractor. + """ + CONNECTION_CONFIG_KEY = 'connection' + SQL_CONFIG_KEY = 'sql' + + def init(self, conf: ConfigTree) -> None: + """ + Receives a {Connection} object and {sql} to execute. + An optional model class can be passed, in which, sql result row + would be converted to a class instance and returned to calling + function + :param conf: + :return: + """ + self.conf = conf + self.connection: Any = conf.get(DBAPIExtractor.CONNECTION_CONFIG_KEY) + self.cursor = self.connection.cursor() + self.sql = conf.get(DBAPIExtractor.SQL_CONFIG_KEY) + + model_class = conf.get('model_class', None) + if model_class: + module_name, class_name = model_class.rsplit(".", 1) + mod = importlib.import_module(module_name) + self.model_class = getattr(mod, class_name) + + self._iter = iter(self._execute_query()) + + def _execute_query(self) -> Iterable[Any]: + """ + Use cursor to execute the {sql} + :return: + """ + LOGGER.info('Executing query: \n%s', self.sql) + self.cursor.execute(self.sql) + return self.cursor.fetchall() + + def extract(self) -> Any: + """ + Fetch one sql result row, convert to {model_class} if specified before + returning. + :return: + """ + + try: + result = next(self._iter) + except StopIteration: + return None + + if hasattr(self, 'model_class'): + obj = self.model_class(*result[:len(result)]) + return obj + else: + return result + + def close(self) -> None: + """ + close cursor and connection handlers + :return: + """ + try: + self.cursor.close() + self.connection.close() + except Exception as e: + LOGGER.warning("Exception encountered while closing up connection handler!", e) + + def get_scope(self) -> str: + return 'extractor.dbapi' diff --git a/databuilder/databuilder/extractor/dbt_extractor.py b/databuilder/databuilder/extractor/dbt_extractor.py new file mode 100644 index 0000000000..929ab8af05 --- /dev/null +++ b/databuilder/databuilder/extractor/dbt_extractor.py @@ -0,0 +1,330 @@ +# Copyright Contributors to the Amundsen project. 
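+#
+# A rough configuration sketch for the DbtExtractor defined below. The manifest and catalog can be
+# supplied either as file paths or as already-serialized JSON strings; the paths and values shown
+# here are hypothetical:
+#
+#     from pyhocon import ConfigFactory
+#
+#     job_config = ConfigFactory.from_dict({
+#         'extractor.dbt.database_name': 'snowflake',
+#         'extractor.dbt.manifest_json': './target/manifest.json',
+#         'extractor.dbt.catalog_json': './target/catalog.json',
+#         'extractor.dbt.extract_lineage': True,
+#         'extractor.dbt.import_tags_as': 'badge',
+#         'extractor.dbt.source_url': 'https://github.com/example-org/dbt-models/blob/main',
+#     })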
+# SPDX-License-Identifier: Apache-2.0 + +import json +import logging +import os +from enum import Enum +from typing import ( + Dict, Iterator, List, Optional, Tuple, Union, +) + +from pyhocon import ConfigTree + +from databuilder.extractor.base_extractor import Extractor +from databuilder.models.badge import Badge, BadgeMetadata +from databuilder.models.table_lineage import TableLineage +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata +from databuilder.models.table_source import TableSource + +LOGGER = logging.getLogger(__name__) + + +DBT_CATALOG_REQD_KEYS = ['nodes'] +DBT_MANIFEST_REQD_KEYS = ['nodes', 'child_map'] +DBT_MODEL_TYPE = 'model' +DBT_MODEL_PREFIX = 'model.' +DBT_TEST_PREFIX = 'test.' + + +class DBT_TAG_AS(Enum): + BADGE = 'badge' + TAG = 'tag' + + +class DBT_MODEL_NAME_KEY(Enum): + ALIAS = 'alias' + NAME = 'name' + + +class InvalidDbtInputs(Exception): + pass + + +class DbtExtractor(Extractor): + """ + Extracts metadata from the dbt manifest.json and catalog.json files. + At least one of a manifest or a catalog (or both) must be provided. + The location of the file or a valid Python dictionary of the content + can be provided. + + Currently the following assets are extracted from these files: + + - Tables + - Columns + - Definitions + - Table lineage + - Tags (converted to Amundsen Badges) + + Additional metadagta exists and may be extracted in the future: + + - Run / test outcomes + - Freshness + - Hooks (as programatic description?) + - Analysis (as queries for a table??) + - Table / column level statistics + - Table comments (as programatic descriptoins) + """ + + CATALOG_JSON = "catalog_json" + MANIFEST_JSON = "manifest_json" + DATABASE_NAME = 'database_name' + + # Dbt Extract Options + EXTRACT_TABLES = 'extract_tables' + EXTRACT_DESCRIPTIONS = 'extract_descriptions' + EXTRACT_TAGS = 'extract_tags' + EXTRACT_LINEAGE = 'extract_lineage' + SOURCE_URL = 'source_url' # Base source code URL for the repo containing dbt workflows + IMPORT_TAGS_AS = 'import_tags_as' + SCHEMA_FILTER = 'schema_filter' # Only extract dbt models from this schema, defaults to all models + MODEL_NAME_KEY = 'model_name_key' # Whether to use the "name" or "alias" from dbt as the Amundsen name + + # Makes all db, schema, cluster and table names lowercase. This is done so that table metadata from dbt + # with the default key `Sample://Cluster/Schema/Table` match existing metadata that Amundsen has from + # the database, which may be `sample://cluster/schema/table`. + # Most databases that dbt integrates with either use lowercase by default in the information schema + # or the default Amundsen extractor applies a `lower(...)` function to the result (e.g. snowflake). + # However, Amundsen does not currently enforce a consistent convention and some databases do support + # upper and lowercase naming conventions (e.g. Redshift). It may be useful to set this False in the + # config if the table metadata keys in your database are not all lowercase and to then use a transformer to + # properly format the string value. 
+ FORCE_TABLE_KEY_LOWER = 'force_table_key_lower' + + def init(self, conf: ConfigTree) -> None: + self._conf = conf + self._database_name = conf.get_string(DbtExtractor.DATABASE_NAME) + self._dbt_manifest = conf.get_string(DbtExtractor.MANIFEST_JSON) + self._dbt_catalog = conf.get_string(DbtExtractor.CATALOG_JSON) + # Extract options + self._extract_tables = conf.get_bool(DbtExtractor.EXTRACT_TABLES, True) + self._extract_descriptions = conf.get_bool(DbtExtractor.EXTRACT_DESCRIPTIONS, True) + self._extract_tags = conf.get_bool(DbtExtractor.EXTRACT_TAGS, True) + self._extract_lineage = conf.get_bool(DbtExtractor.EXTRACT_LINEAGE, True) + self._source_url = conf.get_string(DbtExtractor.SOURCE_URL, None) + self._force_table_key_lower = conf.get_bool(DbtExtractor.FORCE_TABLE_KEY_LOWER, True) + self._dbt_tag_as = DBT_TAG_AS(conf.get_string(DbtExtractor.IMPORT_TAGS_AS, DBT_TAG_AS.BADGE.value)) + self._schema_filter = conf.get_string(DbtExtractor.SCHEMA_FILTER, '') + self._model_name_key = DBT_MODEL_NAME_KEY( + conf.get_string(DbtExtractor.MODEL_NAME_KEY, DBT_MODEL_NAME_KEY.NAME.value)).value + self._clean_inputs() + self._extract_iter: Union[None, Iterator] = None + + def get_scope(self) -> str: + return "extractor.dbt" + + def _validate_catalog(self) -> None: + # Load the catalog file if needed and run basic validation on the content + try: + self._dbt_catalog = json.loads(self._dbt_catalog) + except Exception: + try: + with open(self._dbt_catalog, 'rb') as f: + self._dbt_catalog = json.loads(f.read().lower()) + except Exception as e: + raise InvalidDbtInputs( + 'Invalid content for a dbt catalog was provided. Must be a valid Python ' + 'dictionary or the location of a file. Error received: %s' % e + ) + for catalog_key in DBT_CATALOG_REQD_KEYS: + if catalog_key not in self._dbt_catalog: + raise InvalidDbtInputs( + "Dbt catalog file must contain keys: %s, found keys: %s" + % (DBT_CATALOG_REQD_KEYS, self._dbt_catalog.keys()) + ) + + def _validate_manifest(self) -> None: + # Load the manifest file if needed and run basic validation on the content + try: + self._dbt_manifest = json.loads(self._dbt_manifest) + except Exception: + try: + with open(self._dbt_manifest, 'rb') as f: + self._dbt_manifest = json.loads(f.read().lower()) + except Exception as e: + raise InvalidDbtInputs( + 'Invalid content for a dbt manifest was provided. Must be a valid Python ' + 'dictionary or the location of a file. Error received: %s' % e + ) + for manifest_key in DBT_MANIFEST_REQD_KEYS: + if manifest_key not in self._dbt_manifest: + raise InvalidDbtInputs( + "Dbt manifest file must contain keys: %s, found keys: %s" + % (DBT_MANIFEST_REQD_KEYS, self._dbt_manifest.keys()) + ) + + def _clean_inputs(self) -> None: + """ + Validates the dbt input to ensure that at least one of the inputs + (manifest.json or catalog.json) are provided. Once validated, the + inputs are sanitized to ensure that the `self._dbt_manifest` and + `self._dbt_catalog` are valid Python dictionaries. + """ + if self._database_name is None: + raise InvalidDbtInputs( + 'Must provide a database name that corresponds to this dbt catalog and manifest.' + ) + + if not self._dbt_manifest or not self._dbt_catalog: + raise InvalidDbtInputs( + 'Must provide a dbt manifest file and dbt catalog file.' + ) + + self._validate_catalog() + self._validate_manifest() + + def extract(self) -> Union[TableMetadata, None]: + """ + For every feature table from Feast, a multiple objets are extracted: + + 1. TableMetadata with feature table description + 2. 
Programmatic Description of the feature table, containing + metadata - date of creation and labels + 3. Programmatic Description with Batch Source specification + 4. (if applicable) Programmatic Description with Stream Source + specification + """ + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def _default_sanitize(self, s: str) -> str: + """ + Default function that will be run to convert the value of a string to lowercase. + """ + if s and self._force_table_key_lower: + s = s.lower() + return s + + def _get_table_descriptions(self, manifest_content: Dict) -> Tuple[Optional[str], Optional[str]]: + """ + Gets a description and description source for a table. + """ + desc, desc_src = None, None + if self._extract_descriptions: + desc = manifest_content.get('description') + desc_src = 'dbt description' + return desc, desc_src + + def _get_table_tags_badges(self, manifest_content: Dict) -> Tuple[Optional[List[str]], Optional[List[str]]]: + """ + Gets tags or badges for a given table. At most one of these values will not be null. + """ + tags, tbl_badges = None, None + if self._extract_tags: + if self._dbt_tag_as == DBT_TAG_AS.BADGE: + tbl_badges = manifest_content.get('tags') + elif self._dbt_tag_as == DBT_TAG_AS.TAG: + tags = manifest_content.get('tags') + return tags, tbl_badges + + def _can_yield_schema(self, schema: str) -> bool: + """ + Whether or not the schema can be yielded based on the schema filter criteria. + """ + return (not self._schema_filter) or (self._schema_filter.lower() == schema.lower()) + + def _get_extract_iter(self) -> Iterator[Union[TableMetadata, BadgeMetadata, TableSource, TableLineage]]: + """ + Generates the extract iterator for all of the model types created by the dbt files. 
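+        Depending on the configured extract options this yields TableMetadata, BadgeMetadata (when
+        dbt tags are imported as badges), TableSource (when a source_url is configured) and
+        TableLineage entities.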
+ """ + dbt_id_to_table_key = {} + for tbl_node, manifest_content in self._dbt_manifest['nodes'].items(): + + if manifest_content['resource_type'] == DBT_MODEL_TYPE and tbl_node in self._dbt_catalog['nodes']: + LOGGER.info( + 'Extracting dbt {}.{}'.format(manifest_content['schema'], manifest_content[self._model_name_key]) + ) + + catalog_content = self._dbt_catalog['nodes'][tbl_node] + + tbl_columns: List[ColumnMetadata] = self._get_column_values( + manifest_columns=manifest_content['columns'], catalog_columns=catalog_content['columns'] + ) + + desc, desc_src = self._get_table_descriptions(manifest_content) + tags, tbl_badges = self._get_table_tags_badges(manifest_content) + + tbl_metadata = TableMetadata( + database=self._default_sanitize(self._database_name), + # The dbt "database" is the cluster here + cluster=self._default_sanitize(manifest_content['database']), + schema=self._default_sanitize(manifest_content['schema']), + name=self._default_sanitize(manifest_content[self._model_name_key]), + is_view=catalog_content['metadata']['type'] == 'view', + columns=tbl_columns, + tags=tags, + description=desc, + description_source=desc_src + ) + # Keep track for Lineage + dbt_id_to_table_key[tbl_node] = tbl_metadata._get_table_key() + + # Optionally filter schemas in the output + yield_schema = self._can_yield_schema(manifest_content['schema']) + + if self._extract_tables and yield_schema: + yield tbl_metadata + + if self._extract_tags and tbl_badges and yield_schema: + yield BadgeMetadata(start_label=TableMetadata.TABLE_NODE_LABEL, + start_key=tbl_metadata._get_table_key(), + badges=[Badge(badge, 'table') for badge in tbl_badges]) + + if self._source_url and yield_schema: + yield TableSource(db_name=tbl_metadata.database, + cluster=tbl_metadata.cluster, + schema=tbl_metadata.schema, + table_name=tbl_metadata.name, + source=os.path.join(self._source_url, manifest_content.get('original_file_path'))) + + if self._extract_lineage: + for upstream, downstreams in self._dbt_manifest['child_map'].items(): + if upstream not in dbt_id_to_table_key: + continue + valid_downstreams = [ + dbt_id_to_table_key[k] for k in downstreams + if k.startswith(DBT_MODEL_PREFIX) and dbt_id_to_table_key.get(k) + ] + if valid_downstreams: + yield TableLineage( + table_key=dbt_id_to_table_key[upstream], + downstream_deps=valid_downstreams + ) + + def _get_column_values(self, manifest_columns: Dict, catalog_columns: Dict) -> List[ColumnMetadata]: + """ + Iterates over the columns in the manifest file and creates a `ColumnMetadata` object + with the combined information from the manifest file as well as the catalog file. + + :params manifest_columns: A dictionary of values from the manifest.json, the keys + are column names and the values are column metadata + :params catalog_columns: A dictionary of values from the catalog.json, the keys + are column names and the values are column metadata + :returns: A list of `ColumnMetadata` in Amundsen. 
+ """ + tbl_columns = [] + for catalog_col_name, catalog_col_content in catalog_columns.items(): + manifest_col_content = manifest_columns.get(catalog_col_name, {}) + if catalog_col_content: + col_desc = None + if self._extract_descriptions: + col_desc = manifest_col_content.get('description') + + # Only extract column-level tags IF converting to badges, Amundsen does not have column-level tags + badges = None + if self._extract_tags and self._dbt_tag_as == DBT_TAG_AS.BADGE: + badges = manifest_col_content.get('tags') + + col_metadata = ColumnMetadata( + name=self._default_sanitize(catalog_col_content['name']), + description=col_desc, + col_type=catalog_col_content['type'], + sort_order=catalog_col_content['index'], + badges=badges + ) + tbl_columns.append(col_metadata) + return tbl_columns diff --git a/databuilder/databuilder/extractor/delta_lake_metadata_extractor.py b/databuilder/databuilder/extractor/delta_lake_metadata_extractor.py new file mode 100644 index 0000000000..f638ba146a --- /dev/null +++ b/databuilder/databuilder/extractor/delta_lake_metadata_extractor.py @@ -0,0 +1,534 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import concurrent.futures +import logging +from collections import namedtuple +from datetime import datetime +from typing import ( # noqa: F401 + Any, Dict, Iterator, List, Optional, Tuple, Union, +) + +from pyhocon import ConfigFactory, ConfigTree # noqa: F401 +from pyspark.sql import SparkSession +from pyspark.sql.catalog import Table +from pyspark.sql.types import ( + ArrayType, MapType, StructField, StructType, +) +from pyspark.sql.utils import AnalysisException, ParseException + +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.table_metadata_constants import PARTITION_BADGE +from databuilder.models.table_last_updated import TableLastUpdated +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata +from databuilder.models.watermark import Watermark + +TableKey = namedtuple('TableKey', ['schema', 'table_name']) + +LOGGER = logging.getLogger(__name__) + + +# TODO once column tags work properly, consider deprecating this for TableMetadata directly +class ScrapedColumnMetadata(object): + def __init__(self, name: str, data_type: str, description: Optional[str], sort_order: int, + badges: Union[List[str], None] = None): + self.name = name + self.data_type = data_type + self.description = description + self.sort_order = sort_order + self.is_partition = False + self.attributes: Dict[str, str] = {} + self.badges = badges + + def set_is_partition(self, is_partition: bool) -> None: + self.is_partition = is_partition + + def set_badges(self, badges: Union[List[str], None]) -> None: + self.badges = badges + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ScrapedColumnMetadata): + return False + return (self.name == other.name and + self.data_type == other.data_type and + self.description == other.description and + self.sort_order == other.sort_order and + self.is_partition == other.is_partition and + self.attributes == other.attributes and + self.badges == other.badges) + + def __repr__(self) -> str: + return f'{self.name}:{self.data_type}' + + +# TODO consider deprecating this for using TableMetadata directly +class ScrapedTableMetadata(object): + LAST_MODIFIED_KEY = 'lastModified' + DESCRIPTION_KEY = 'description' + TABLE_FORMAT_KEY = 'format' + + def __init__(self, schema: str, table: str): + self.schema: str = schema + self.table: str = 
table + self.table_detail: Optional[Dict] = None + self.view_detail: Optional[Dict] = None + self.is_view: bool = False + self.failed_to_scrape: bool = False + self.columns: Optional[List[ScrapedColumnMetadata]] = None + + def set_table_detail(self, table_detail: Dict) -> None: + self.table_detail = table_detail + self.is_view = False + self.failed_to_scrape = False + + def set_view_detail(self, view_detail: Dict) -> None: + self.view_detail = view_detail + self.is_view = True + self.failed_to_scrape = False + + def get_details(self) -> Optional[Dict]: + if self.is_view: + return self.view_detail + else: + return self.table_detail + + def get_full_table_name(self) -> str: + return self.schema + "." + self.table + + def set_failed_to_scrape(self) -> None: + self.failed_to_scrape = True + + def set_columns(self, column_list: List[ScrapedColumnMetadata]) -> None: + self.columns = column_list + + def get_last_modified(self) -> Optional[datetime]: + details = self.get_details() + if details and self.LAST_MODIFIED_KEY in details: + return details[self.LAST_MODIFIED_KEY] + else: + return None + + def get_table_description(self) -> Optional[str]: + details = self.get_details() + if details and self.DESCRIPTION_KEY in details: + return details[self.DESCRIPTION_KEY] + else: + return None + + def is_delta_table(self) -> bool: + details = self.get_details() + if details and self.TABLE_FORMAT_KEY in details: + return details[self.TABLE_FORMAT_KEY].lower() == 'delta' + else: + return False + + def __repr__(self) -> str: + return f'{self.schema}.{self.table}' + + +class DeltaLakeMetadataExtractor(Extractor): + """ + Extracts Delta Lake Metadata. + This requires a spark session to run that has a hive metastore populated with all of the delta tables + that you are interested in. + + By default, the extractor does not extract nested columns. Set the EXTRACT_NESTED_COLUMNS conf to True + if you would like nested columns extracted + """ + # CONFIG KEYS + DATABASE_KEY = "database" + # If you want to exclude specific schemas + EXCLUDE_LIST_SCHEMAS_KEY = "exclude_list" + # If you want to only include specific schemas + SCHEMA_LIST_KEY = "schema_list" + CLUSTER_KEY = "cluster" + # By default, this will only process and emit delta-lake tables, but it can support all hive table types. 
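+    # (Set the DELTA_TABLES_ONLY config key below to False to also emit non-delta Hive tables.)
+    #
+    # Rough usage sketch; the Spark session setup and config values are hypothetical:
+    #
+    #     spark = SparkSession.builder.enableHiveSupport().getOrCreate()
+    #     extractor = DeltaLakeMetadataExtractor()
+    #     extractor.set_spark(spark)
+    #     extractor.init(ConfigFactory.from_dict({
+    #         DeltaLakeMetadataExtractor.CLUSTER_KEY: 'prod',
+    #         DeltaLakeMetadataExtractor.SCHEMA_LIST_KEY: ['sales', 'marketing'],
+    #     }))
+    #     record = extractor.extract()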
+ DELTA_TABLES_ONLY = "delta_tables_only" + DEFAULT_CONFIG = ConfigFactory.from_dict({DATABASE_KEY: "delta", + EXCLUDE_LIST_SCHEMAS_KEY: [], + SCHEMA_LIST_KEY: [], + DELTA_TABLES_ONLY: True}) + PARTITION_COLUMN_TAG = 'is_partition' + + # For backwards compatibility, the delta lake extractor does not extract nested columns for indexing + # Set this to true in the conf if you would like nested columns & complex types fully extracted + EXTRACT_NESTED_COLUMNS = "extract_nested_columns" + + def init(self, conf: ConfigTree) -> None: + self.conf = conf.with_fallback(DeltaLakeMetadataExtractor.DEFAULT_CONFIG) + self._extract_iter = None # type: Union[None, Iterator] + self._cluster = self.conf.get_string(DeltaLakeMetadataExtractor.CLUSTER_KEY) + self._db = self.conf.get_string(DeltaLakeMetadataExtractor.DATABASE_KEY) + self.exclude_list = self.conf.get_list(DeltaLakeMetadataExtractor.EXCLUDE_LIST_SCHEMAS_KEY) + self.schema_list = self.conf.get_list(DeltaLakeMetadataExtractor.SCHEMA_LIST_KEY) + self.delta_tables_only = self.conf.get_bool(DeltaLakeMetadataExtractor.DELTA_TABLES_ONLY) + self.extract_nested_columns = self.conf.get_bool(DeltaLakeMetadataExtractor.EXTRACT_NESTED_COLUMNS, + default=False) + + def set_spark(self, spark: SparkSession) -> None: + self.spark = spark + + def extract(self) -> Union[TableMetadata, List[Tuple[Watermark, Watermark]], TableLastUpdated, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.delta_lake_table_metadata' + + def _get_extract_iter(self) -> Iterator[Union[TableMetadata, Watermark, TableLastUpdated, + None]]: + """ + Given either a list of schemas, or a list of exclude schemas, + it will query hive metastore and then access delta log + to get all of the metadata for your delta tables. It will produce: + - table and column metadata (including partition watermarks) + - last updated information + """ + if self.schema_list: + LOGGER.info("working on %s", self.schema_list) + tables = self.get_all_tables(self.schema_list) + else: + LOGGER.info("fetching all schemas") + LOGGER.info("Excluding: %s", self.exclude_list) + schemas = self.get_schemas(self.exclude_list) + LOGGER.info("working on %s", schemas) + tables = self.get_all_tables(schemas) + # TODO add the programmatic information as well? 
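+        # For every table scraped below: non-delta tables are skipped when delta_tables_only is set,
+        # otherwise its TableMetadata, a (high, low) Watermark pair per partition column (when
+        # available) and a TableLastUpdated record are yielded.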
+ scraped_tables = self.scrape_all_tables(tables) + for scraped_table in scraped_tables: + if not scraped_table: + continue + if self.delta_tables_only and not scraped_table.is_delta_table(): + LOGGER.info("Skipping none delta table %s", scraped_table.table) + continue + else: + yield self.create_table_metadata(scraped_table) + watermarks = self.create_table_watermarks(scraped_table) + if watermarks: + for watermark in watermarks: + yield watermark[0] + yield watermark[1] + last_updated = self.create_table_last_updated(scraped_table) + if last_updated: + yield last_updated + + def get_schemas(self, exclude_list: List[str]) -> List[str]: + '''Returns all schemas.''' + schemas = self.spark.catalog.listDatabases() + ret = [] + for schema in schemas: + if schema.name not in exclude_list: + ret.append(schema.name) + return ret + + def get_all_tables(self, schemas: List[str]) -> List[Table]: + '''Returns all tables.''' + ret = [] + for schema in schemas: + ret.extend(self.get_tables_for_schema(schema)) + return ret + + def get_tables_for_schema(self, schema: str) -> List[Table]: + '''Returns all tables for a specific schema.''' + return self.spark.catalog.listTables(schema) + + def scrape_all_tables(self, tables: List[Table]) -> List[Optional[ScrapedTableMetadata]]: + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [executor.submit(self.scrape_table, table) for table in tables] + scraped_tables = [f.result() for f in futures] + return scraped_tables + + def scrape_table(self, table: Table) -> Optional[ScrapedTableMetadata]: + '''Takes a table object and creates a scraped table metadata object.''' + met = ScrapedTableMetadata(schema=table.database, table=table.name) + table_name = met.get_full_table_name() + if table.tableType and table.tableType.lower() != 'view': + table_detail = self.scrape_table_detail(table_name) + if table_detail is None: + LOGGER.error("Failed to parse table " + table_name) + met.set_failed_to_scrape() + return None + else: + LOGGER.info("Successfully parsed table " + table_name) + met.set_table_detail(table_detail) + else: + view_detail = self.scrape_view_detail(table_name) + if view_detail is None: + LOGGER.error("Failed to parse view " + table_name) + met.set_failed_to_scrape() + return None + else: + LOGGER.info("Successfully parsed view " + table_name) + met.set_view_detail(view_detail) + columns = self.fetch_columns(met.schema, met.table) + if not columns: + LOGGER.error("Failed to parse columns for " + table_name) + return None + else: + met.set_columns(columns) + return met + + def scrape_table_detail(self, table_name: str) -> Optional[Dict]: + try: + table_details_df = self.spark.sql(f"describe detail {table_name}") + table_detail = table_details_df.collect()[0] + return table_detail.asDict() + except Exception as e: + LOGGER.error(e) + return None + + def scrape_view_detail(self, view_name: str) -> Optional[Dict]: + # TODO the blanket try catches need to be changed + describeExtendedOutput = [] + try: + describeExtendedOutput = self.spark.sql(f"describe extended {view_name}").collect() + except Exception as e: + LOGGER.error(e) + return None + view_detail = {} + startAdding = False + for row in describeExtendedOutput: + row_dict = row.asDict() + if startAdding: + view_detail[row_dict['col_name']] = row_dict['data_type'] + if "# Detailed Table" in row_dict['col_name']: + # Then start parsing + startAdding = True + return view_detail + + def fetch_columns(self, schema: str, table: str) -> List[ScrapedColumnMetadata]: + '''This fetches delta 
table columns, which unfortunately + in the general case cannot rely on spark.catalog.listColumns.''' + raw_columns = [] + field_dict: Dict[str, Any] = {} + table_name = f"{schema}.{table}" + try: + raw_columns = self.spark.sql(f"describe {table_name}").collect() + for field in self.spark.table(f"{table_name}").schema: + field_dict[field.name] = field + except (AnalysisException, ParseException) as e: + LOGGER.error(e) + return [] + parsed_columns: Dict[str, ScrapedColumnMetadata] = {} + partition_cols = False + sort_order = 0 + for row in raw_columns: + col_name = row['col_name'] + # NOTE: the behavior of describe has changed between spark 2 and spark 3 + if col_name == '' or '#' in col_name: + partition_cols = True + continue + if not partition_cols: + # Attempt to extract nested columns if conf value requests it + if self.extract_nested_columns \ + and col_name in field_dict \ + and self.is_complex_delta_type(field_dict[col_name].dataType): + sort_order = self._iterate_complex_type("", field_dict[col_name], parsed_columns, sort_order) + else: + column = ScrapedColumnMetadata( + name=row['col_name'], + description=row['comment'] if row['comment'] else None, + data_type=row['data_type'], + sort_order=sort_order, + badges=None + ) + parsed_columns[row['col_name']] = column + sort_order += 1 + else: + if row['data_type'] in parsed_columns: + LOGGER.debug(f"Adding partition column table for {row['data_type']}") + parsed_columns[row['data_type']].set_is_partition(True) + parsed_columns[row['data_type']].set_badges([PARTITION_BADGE]) + elif row['col_name'] in parsed_columns: + LOGGER.debug(f"Adding partition column table for {row['col_name']}") + parsed_columns[row['col_name']].set_is_partition(True) + parsed_columns[row['col_name']].set_badges([PARTITION_BADGE]) + return list(parsed_columns.values()) + + def _iterate_complex_type(self, + parent: str, + curr_field: Union[StructType, StructField, ArrayType, MapType], + parsed_columns: Dict, + total_cols: int) -> int: + col_name = parent + if self.is_struct_field(curr_field): + if len(parent) > 0: + col_name = f"{parent}.{curr_field.name}" + else: + col_name = curr_field.name + + parsed_columns[col_name] = ScrapedColumnMetadata( + name=col_name, + data_type=curr_field.dataType.simpleString(), + sort_order=total_cols, + description=None, + ) + total_cols += 1 + if self.is_complex_delta_type(curr_field.dataType): + total_cols = self._iterate_complex_type(col_name, curr_field.dataType, parsed_columns, total_cols) + + if self.is_complex_delta_type(curr_field): + if self.is_struct_type(curr_field): + for field in curr_field: + total_cols = self._iterate_complex_type(col_name, field, parsed_columns, total_cols) + elif self.is_array_type(curr_field) and self.is_complex_delta_type(curr_field.elementType): + total_cols = self._iterate_complex_type(col_name, curr_field.elementType, parsed_columns, total_cols) + elif self.is_map_type(curr_field) and self.is_complex_delta_type(curr_field.valueType): + total_cols = self._iterate_complex_type(col_name, curr_field.valueType, parsed_columns, total_cols) + + return total_cols + + def create_table_metadata(self, table: ScrapedTableMetadata) -> TableMetadata: + '''Creates the amundsen table metadata object from the ScrapedTableMetadata object.''' + amundsen_columns = [] + if table.columns: + for column in table.columns: + amundsen_columns.append( + ColumnMetadata(name=column.name, + description=column.description, + col_type=column.data_type, + sort_order=column.sort_order, + badges=column.badges) + ) + 
description = table.get_table_description() + return TableMetadata(self._db, + self._cluster, + table.schema, + table.table, + description, + amundsen_columns, + table.is_view) + + def create_table_last_updated(self, table: ScrapedTableMetadata) -> Optional[TableLastUpdated]: + '''Creates the amundsen table last updated metadata object from the ScrapedTableMetadata object.''' + last_modified = table.get_last_modified() + if last_modified: + return TableLastUpdated(table_name=table.table, + last_updated_time_epoch=int(last_modified.timestamp()), + schema=table.schema, + db=self._db, + cluster=self._cluster) + else: + return None + + def is_complex_delta_type(self, delta_type: Any) -> bool: + return isinstance(delta_type, StructType) or \ + isinstance(delta_type, ArrayType) or \ + isinstance(delta_type, MapType) + + def is_struct_type(self, delta_type: Any) -> bool: + return isinstance(delta_type, StructType) + + def is_struct_field(self, delta_type: Any) -> bool: + return isinstance(delta_type, StructField) + + def is_array_type(self, delta_type: Any) -> bool: + return isinstance(delta_type, ArrayType) + + def is_map_type(self, delta_type: Any) -> bool: + return isinstance(delta_type, MapType) + + def create_table_watermarks(self, table: ScrapedTableMetadata) -> Optional[List[Tuple[Watermark, Watermark]]]: # noqa c901 + """ + Creates the watermark objects that reflect the highest and lowest values in the partition columns + """ + def _is_show_partitions_supported(t: ScrapedTableMetadata) -> bool: + try: + self.spark.sql(f'show partitions {t.schema}.{t.table}') + return True + except Exception as e: + # pyspark.sql.utils.AnalysisException: SHOW PARTITIONS is not allowed on a table that is not partitioned + LOGGER.warning(e) + return False + + def _fetch_minmax(table: ScrapedTableMetadata, partition_column: str) -> Tuple[str, str]: + LOGGER.info(f'Fetching partition info for {partition_column} in {table.schema}.{table.table}') + min_water = "" + max_water = "" + try: + if is_show_partitions_supported: + LOGGER.info('Using SHOW PARTITION') + min_water = str( + self + .spark + .sql(f'show partitions {table.schema}.{table.table}') + .orderBy(partition_column, ascending=True) + .first()[partition_column]) + max_water = str( + self + .spark + .sql(f'show partitions {table.schema}.{table.table}') + .orderBy(partition_column, ascending=False) + .first()[partition_column]) + else: + LOGGER.info('Using DESCRIBE EXTENDED') + part_info = (self + .spark + .sql(f'describe extended {table.schema}.{table.table} {partition_column}') + .collect() + ) + minmax = {} + for mm in list(filter(lambda x: x['info_name'] in ['min', 'max'], part_info)): + minmax[mm['info_name']] = mm['info_value'] + min_water = minmax['min'] + max_water = minmax['max'] + except Exception as e: + LOGGER.warning(f'Failed fetching partition watermarks: {e}') + return max_water, min_water + + if not table.table_detail: + LOGGER.info(f'No table details found in {table}, skipping') + return None + + if 'partitionColumns' not in table.table_detail or len(table.table_detail['partitionColumns']) < 1: + LOGGER.info(f'No partitions found in {table}, skipping') + return None + + is_show_partitions_supported: bool = _is_show_partitions_supported(table) + + if not is_show_partitions_supported: + LOGGER.info('Analyzing table, this can take a while...') + partition_columns = ','.join(table.table_detail['partitionColumns']) + self.spark.sql( + f"analyze table {table.schema}.{table.table} compute statistics for columns {partition_columns}") + + # It 
makes little sense to get watermarks from a string value, with no concept of high and low. + # Just imagine a dataset with a partition by country... + valid_types = ['int', 'float', 'date', 'datetime'] + if table.columns: + _table_columns = table.columns + else: + _table_columns = [] + columns_with_valid_type = list(map(lambda n: n.name, + filter(lambda d: str(d.data_type).lower() in valid_types, _table_columns) + ) + ) + + r = [] + for partition_column in table.table_detail['partitionColumns']: + if partition_column not in columns_with_valid_type: + continue + + last, first = _fetch_minmax(table, partition_column) + low = Watermark( + create_time=table.table_detail['createdAt'], + database=self._db, + schema=table.schema, + table_name=table.table, + part_name=f'{partition_column}={first}', + part_type='low_watermark', + cluster=self._cluster) + high = Watermark( + create_time=table.table_detail['createdAt'], + database=self._db, + schema=table.schema, + table_name=table.table, + part_name=f'{partition_column}={last}', + part_type='high_watermark', + cluster=self._cluster) + r.append((high, low)) + return r diff --git a/databuilder/databuilder/extractor/dremio_metadata_extractor.py b/databuilder/databuilder/extractor/dremio_metadata_extractor.py new file mode 100644 index 0000000000..71ac03651e --- /dev/null +++ b/databuilder/databuilder/extractor/dremio_metadata_extractor.py @@ -0,0 +1,178 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from collections import namedtuple +from itertools import groupby +from typing import ( + Any, Dict, Iterator, Union, +) + +from pyhocon import ConfigFactory, ConfigTree +from pyodbc import connect + +from databuilder.extractor.base_extractor import Extractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +TableKey = namedtuple('TableKey', ['schema', 'table_name']) + +LOGGER = logging.getLogger(__name__) + + +class DremioMetadataExtractor(Extractor): + ''' + Extracts Dremio table and column metadata from underlying INFORMATION_SCHEMA table + + Requirements: + pyodbc & Dremio driver + ''' + + SQL_STATEMENT = ''' + SELECT + nested_1.COLUMN_NAME AS col_name, + CAST(NULL AS VARCHAR) AS col_description, + nested_1.DATA_TYPE AS col_type, + nested_1.ORDINAL_POSITION AS col_sort_order, + nested_1.TABLE_CATALOG AS database, + '{cluster}' AS cluster, + nested_1.TABLE_SCHEMA AS schema, + nested_1.TABLE_NAME AS name, + CAST(NULL AS VARCHAR) AS description, + CASE WHEN nested_0.TABLE_TYPE='VIEW' THEN TRUE ELSE FALSE END AS is_view + FROM ( + SELECT TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE + FROM INFORMATION_SCHEMA."TABLES" + ) nested_0 + RIGHT JOIN ( + SELECT TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE, ORDINAL_POSITION + FROM INFORMATION_SCHEMA."COLUMNS" + ) nested_1 ON nested_0.TABLE_NAME = nested_1.TABLE_NAME + AND nested_0.TABLE_SCHEMA = nested_1.TABLE_SCHEMA + AND nested_0.TABLE_CATALOG = nested_1.TABLE_CATALOG + {where_stmt} + ''' + + # Config keys + DREMIO_USER_KEY = 'user_key' + DREMIO_PASSWORD_KEY = 'password_key' + DREMIO_HOST_KEY = 'host_key' + DREMIO_PORT_KEY = 'port_key' + DREMIO_DRIVER_KEY = 'driver_key' + DREMIO_CLUSTER_KEY = 'cluster_key' + DREMIO_EXCLUDE_SYS_TABLES_KEY = 'exclude_system_tables' + DREMIO_EXCLUDE_PDS_TABLES_KEY = 'exclude_pds_tables' + + # Default values + DEFAULT_AUTH_USER = 'dremio_auth_user' + DEFAULT_AUTH_PW = 'dremio_auth_pw' + DEFAULT_HOST = 'localhost' + DEFAULT_PORT = '31010' + DEFAULT_DRIVER = 
'DSN=Dremio Connector' + DEFAULT_CLUSTER_NAME = 'Production' + DEFAULT_EXCLUDE_SYS_TABLES = True + DEFAULT_EXCLUDE_PDS_TABLES = False + + # Default config + DEFAULT_CONFIG = ConfigFactory.from_dict({ + DREMIO_USER_KEY: DEFAULT_AUTH_USER, + DREMIO_PASSWORD_KEY: DEFAULT_AUTH_PW, + DREMIO_HOST_KEY: DEFAULT_HOST, + DREMIO_PORT_KEY: DEFAULT_PORT, + DREMIO_DRIVER_KEY: DEFAULT_DRIVER, + DREMIO_CLUSTER_KEY: DEFAULT_CLUSTER_NAME, + DREMIO_EXCLUDE_SYS_TABLES_KEY: DEFAULT_EXCLUDE_SYS_TABLES, + DREMIO_EXCLUDE_PDS_TABLES_KEY: DEFAULT_EXCLUDE_PDS_TABLES + }) + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(DremioMetadataExtractor.DEFAULT_CONFIG) + + exclude_sys_tables = conf.get_bool(DremioMetadataExtractor.DREMIO_EXCLUDE_SYS_TABLES_KEY) + exclude_pds_tables = conf.get_bool(DremioMetadataExtractor.DREMIO_EXCLUDE_PDS_TABLES_KEY) + if exclude_sys_tables and exclude_pds_tables: + where_stmt = ("WHERE nested_0.TABLE_TYPE != 'SYSTEM_TABLE' AND " + "nested_0.TABLE_TYPE != 'TABLE';") + elif exclude_sys_tables: + where_stmt = "WHERE nested_0.TABLE_TYPE != 'SYSTEM_TABLE';" + elif exclude_pds_tables: + where_stmt = "WHERE nested_0.TABLE_TYPE != 'TABLE';" + else: + where_stmt = ';' + + self._cluster = conf.get_string(DremioMetadataExtractor.DREMIO_CLUSTER_KEY) + + self._cluster = conf.get_string(DremioMetadataExtractor.DREMIO_CLUSTER_KEY) + + self.sql_stmt = DremioMetadataExtractor.SQL_STATEMENT.format( + cluster=self._cluster, + where_stmt=where_stmt + ) + + LOGGER.info('SQL for Dremio metadata: %s', self.sql_stmt) + + self._pyodbc_cursor = connect( + conf.get_string(DremioMetadataExtractor.DREMIO_DRIVER_KEY), + uid=conf.get_string(DremioMetadataExtractor.DREMIO_USER_KEY), + pwd=conf.get_string(DremioMetadataExtractor.DREMIO_PASSWORD_KEY), + host=conf.get_string(DremioMetadataExtractor.DREMIO_HOST_KEY), + port=conf.get_string(DremioMetadataExtractor.DREMIO_PORT_KEY), + autocommit=True).cursor() + + self._extract_iter: Union[None, Iterator] = None + + def extract(self) -> Union[TableMetadata, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.dremio' + + def _get_extract_iter(self) -> Iterator[TableMetadata]: + ''' + Using itertools.groupby and raw level iterator, it groups to table and yields TableMetadata + :return: + ''' + for _, group in groupby(self._get_raw_extract_iter(), self._get_table_key): + columns = [] + + for row in group: + last_row = row + columns.append(ColumnMetadata( + row['col_name'], + row['col_description'], + row['col_type'], + row['col_sort_order']) + ) + + yield TableMetadata(last_row['database'], + last_row['cluster'], + last_row['schema'], + last_row['name'], + last_row['description'], + columns, + last_row['is_view'] == 'true') + + def _get_raw_extract_iter(self) -> Iterator[Dict[str, Any]]: + ''' + Provides iterator of result row from SQLAlchemy extractor + :return: + ''' + + for row in self._pyodbc_cursor.execute(self.sql_stmt): + yield dict(zip([c[0] for c in self._pyodbc_cursor.description], row)) + + def _get_table_key(self, row: Dict[str, Any]) -> Union[TableKey, None]: + ''' + Table key consists of schema and table name + :param row: + :return: + ''' + if row: + return TableKey(schema=row['schema'], table_name=row['name']) + + return None diff --git a/databuilder/databuilder/extractor/druid_metadata_extractor.py b/databuilder/databuilder/extractor/druid_metadata_extractor.py new file mode 
100644 index 0000000000..cd0316f9d8 --- /dev/null +++ b/databuilder/databuilder/extractor/druid_metadata_extractor.py @@ -0,0 +1,112 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import textwrap +from collections import namedtuple +from itertools import groupby +from typing import ( + Any, Dict, Iterator, Union, +) + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor import sql_alchemy_extractor +from databuilder.extractor.base_extractor import Extractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +TableKey = namedtuple('TableKey', ['schema', 'table_name']) + +LOGGER = logging.getLogger(__name__) + + +class DruidMetadataExtractor(Extractor): + """ + Extracts Druid table and column metadata from druid using dbapi extractor + """ + SQL_STATEMENT = textwrap.dedent(""" + SELECT + TABLE_SCHEMA as schema, + TABLE_NAME as name, + COLUMN_NAME as col_name, + DATA_TYPE as col_type, + ORDINAL_POSITION as col_sort_order + FROM INFORMATION_SCHEMA.COLUMNS + {where_clause_suffix} + order by TABLE_SCHEMA, TABLE_NAME, CAST(ORDINAL_POSITION AS int) + """) + + # CONFIG KEYS + WHERE_CLAUSE_SUFFIX_KEY = 'where_clause_suffix' + CLUSTER_KEY = 'cluster' + + DEFAULT_CONFIG = ConfigFactory.from_dict({WHERE_CLAUSE_SUFFIX_KEY: ' ', + CLUSTER_KEY: 'gold'}) + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(DruidMetadataExtractor.DEFAULT_CONFIG) + self._cluster = conf.get_string(DruidMetadataExtractor.CLUSTER_KEY) + + self.sql_stmt = DruidMetadataExtractor.SQL_STATEMENT.format( + where_clause_suffix=conf.get_string(DruidMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY, + default='')) + + self._alchemy_extractor = sql_alchemy_extractor.from_surrounding_config(conf, self.sql_stmt) + self._extract_iter: Union[None, Iterator] = None + + def close(self) -> None: + if getattr(self, '_alchemy_extractor', None) is not None: + self._alchemy_extractor.close() + + def extract(self) -> Union[TableMetadata, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.druid_metadata' + + def _get_extract_iter(self) -> Iterator[TableMetadata]: + """ + Using itertools.groupby and raw level iterator, it groups to table and yields TableMetadata + :return: + """ + for key, group in groupby(self._get_raw_extract_iter(), self._get_table_key): + columns = [] + # no table description and column description + for row in group: + last_row = row + columns.append(ColumnMetadata(name=row['col_name'], + description='', + col_type=row['col_type'], + sort_order=row['col_sort_order'])) + yield TableMetadata(database='druid', + cluster=self._cluster, + schema=last_row['schema'], + name=last_row['name'], + description='', + columns=columns) + + def _get_raw_extract_iter(self) -> Iterator[Dict[str, Any]]: + """ + Provides iterator of result row from dbapi extractor + :return: + """ + row = self._alchemy_extractor.extract() + while row: + yield row + row = self._alchemy_extractor.extract() + + def _get_table_key(self, row: Dict[str, Any]) -> Union[TableKey, None]: + """ + Table key consists of schema and table name + :param row: + :return: + """ + if row: + return TableKey(schema=row['schema'], table_name=row['name']) + + return None diff --git a/databuilder/databuilder/extractor/es_base_extractor.py b/databuilder/databuilder/extractor/es_base_extractor.py new 
file mode 100644 index 0000000000..69015260fc --- /dev/null +++ b/databuilder/databuilder/extractor/es_base_extractor.py @@ -0,0 +1,151 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 +import abc +from typing import ( + Any, Dict, Iterator, List, Optional, Union, +) + +from pyhocon import ConfigTree + +from databuilder.extractor.base_extractor import Extractor +from databuilder.models.table_metadata import ColumnMetadata + + +class ElasticsearchBaseExtractor(Extractor): + """ + Extractor to extract index metadata from Elasticsearch + + By default, the extractor does not add sort_order to columns. Set ELASTICSEARCH_CORRECT_SORT_ORDER conf to True + for columns to have correct sort order. + + Set ELASTICSEARCH_TIME_FIELD to the name of the field representing time. + """ + + ELASTICSEARCH_CLIENT_CONFIG_KEY = 'client' + ELASTICSEARCH_EXTRACT_TECHNICAL_DETAILS = 'extract_technical_details' + + # For backwards compatibility, the Elasticsearch extractor does not add sort_order to columns by default. + # Set this to true in the conf for columns to have correct sort order. + ELASTICSEARCH_CORRECT_SORT_ORDER = 'correct_sort_order' + + # Set this to the name of the field representing time. + ELASTICSEARCH_TIME_FIELD = 'time_field' + + CLUSTER = 'cluster' + SCHEMA = 'schema' + + def __init__(self) -> None: + super(ElasticsearchBaseExtractor, self).__init__() + + def init(self, conf: ConfigTree) -> None: + self.conf = conf + self._extract_iter = self._get_extract_iter() + + self.es = self.conf.get(ElasticsearchBaseExtractor.ELASTICSEARCH_CLIENT_CONFIG_KEY) + + def _get_es_version(self) -> str: + return self.es.info().get('version').get('number') + + def _get_indexes(self) -> Dict: + result = dict() + + try: + _indexes = self.es.indices.get('*') + + for k, v in _indexes.items(): + if not k.startswith('.'): + result[k] = v + except Exception: + pass + + return result + + def _get_index_creation_date(self, index_metadata: Dict) -> Optional[float]: + try: + return float(index_metadata.get('settings', dict()).get('index').get('creation_date')) + except Exception: + return None + + def _get_index_mapping_properties(self, index: Dict) -> Optional[Dict]: + mappings = index.get('mappings', dict()) + + # Mapping types were removed in Elasticsearch 7. As a result, index mappings are formatted differently. 
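+        # Roughly (abridged), the index metadata looks like:
+        #   ES >= 7: {'mappings': {'properties': {'field': {'type': 'keyword'}}}}
+        #   ES < 7:  {'mappings': {'<doc_type>': {'properties': {'field': {'type': 'keyword'}}}}}
+        # which is why the pre-7 branch below reaches one level deeper before reading 'properties'.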
+ # See https://www.elastic.co/guide/en/elasticsearch/reference/current/removal-of-types.html + version = self._get_es_version() + + try: + if int(version.split('.')[0]) >= 7: + properties = mappings.get('properties', dict()) + else: + properties = list(mappings.values())[0].get('properties', dict()) + except Exception: + properties = dict() + + return properties + + def _get_attributes(self, + input_mapping: Dict, + parent_col_name: str = '', + separator: str = '.') -> List[ColumnMetadata]: + cols: List[ColumnMetadata] = [] + + for col_name, col_mapping in input_mapping.items(): + qualified_col_name = str(parent_col_name) + separator + col_name if parent_col_name else col_name + if isinstance(col_mapping, dict): + if col_mapping.__contains__('properties'): + # Need to recurse + inner_mapping: Dict = col_mapping.get('properties', {}) + cols.extend(self._get_attributes(input_mapping=inner_mapping, + parent_col_name=qualified_col_name, + separator=separator)) + else: + cols.append(ColumnMetadata(name=qualified_col_name, + description='', + col_type=col_mapping.get('type', ''), + sort_order=0)) + + return cols + + def extract(self) -> Any: + try: + result = next(self._extract_iter) + + return result + except StopIteration: + return None + + @property + def database(self) -> str: + return 'elasticsearch' + + @property + def cluster(self) -> str: + return self.conf.get(ElasticsearchBaseExtractor.CLUSTER) + + @property + def schema(self) -> str: + return self.conf.get(ElasticsearchBaseExtractor.SCHEMA) + + @property + def _extract_technical_details(self) -> bool: + try: + return self.conf.get(ElasticsearchBaseExtractor.ELASTICSEARCH_EXTRACT_TECHNICAL_DETAILS) + except Exception: + return False + + @property + def _correct_sort_order(self) -> bool: + try: + return self.conf.get(ElasticsearchBaseExtractor.ELASTICSEARCH_CORRECT_SORT_ORDER) + except Exception: + return False + + # Default time field is @timestamp to match ECS + # See https://www.elastic.co/guide/en/ecs/master/ecs-base.html + @property + def _time_field(self) -> str: + return self.conf.get(ElasticsearchBaseExtractor.ELASTICSEARCH_TIME_FIELD, '@timestamp') + + @abc.abstractmethod + def _get_extract_iter(self) -> Iterator[Union[Any, None]]: + pass diff --git a/databuilder/databuilder/extractor/es_column_stats_extractor.py b/databuilder/databuilder/extractor/es_column_stats_extractor.py new file mode 100644 index 0000000000..a4aa514085 --- /dev/null +++ b/databuilder/databuilder/extractor/es_column_stats_extractor.py @@ -0,0 +1,82 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + + +from typing import ( + Any, Dict, Iterator, List, Set, Union, +) + +from databuilder.extractor.es_base_extractor import ElasticsearchBaseExtractor +from databuilder.models.table_stats import TableColumnStats + + +class ElasticsearchColumnStatsExtractor(ElasticsearchBaseExtractor): + """ + Extractor to extract stats for Elasticsearch mapping attributes. 
+ """ + + def get_scope(self) -> str: + return 'extractor.es_column_stats' + + def _get_index_stats(self, index_name: str, fields: List[str]) -> List[Dict[str, Any]]: + query = { + "size": 0, + "aggs": { + "stats": { + "matrix_stats": { + "fields": fields + } + } + } + } + + _data = self.es.search(index=index_name, body=query) + + data = _data.get('aggregations', dict()).get('stats', dict()).get('fields', list()) + + return data + + def _render_column_stats(self, index_name: str, spec: Dict[str, Any]) -> List[TableColumnStats]: + result: List[TableColumnStats] = [] + + col_name = spec.pop('name') + + for stat_name, stat_val in spec.items(): + if isinstance(stat_val, dict) or isinstance(stat_val, list): + continue + elif stat_val == 'NaN': + continue + + stat = TableColumnStats(table_name=index_name, + col_name=col_name, + stat_name=stat_name, + stat_val=stat_val, + start_epoch='0', + end_epoch='0', + db=self.database, + cluster=self.cluster, + schema=self.schema) + + result.append(stat) + + return result + + @property + def _allowed_types(self) -> Set[str]: + return set(['long', 'integer', 'short', 'byte', 'double', + 'float', 'half_float', 'scaled_float', 'unsigned_long']) + + def _get_extract_iter(self) -> Iterator[Union[TableColumnStats, None]]: + indexes: Dict = self._get_indexes() + for index_name, index_metadata in indexes.items(): + properties = self._get_index_mapping_properties(index_metadata) or dict() + + fields = [name for name, spec in properties.items() if spec['type'] in self._allowed_types] + + specifications = self._get_index_stats(index_name, fields) + + for spec in specifications: + stats = self._render_column_stats(index_name, spec) + + for stat in stats: + yield stat diff --git a/databuilder/databuilder/extractor/es_last_updated_extractor.py b/databuilder/databuilder/extractor/es_last_updated_extractor.py new file mode 100644 index 0000000000..eb254fae74 --- /dev/null +++ b/databuilder/databuilder/extractor/es_last_updated_extractor.py @@ -0,0 +1,51 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import importlib +import time +from typing import Any + +from pyhocon import ConfigTree + +from databuilder.extractor.generic_extractor import GenericExtractor + + +class EsLastUpdatedExtractor(GenericExtractor): + """ + Extractor to extract last updated timestamp for Datastore and Es + """ + + def init(self, conf: ConfigTree) -> None: + """ + Receives a list of dictionaries which is used for extraction + :param conf: + :return: + """ + self.conf = conf + + model_class = conf.get('model_class', None) + if model_class: + module_name, class_name = model_class.rsplit(".", 1) + mod = importlib.import_module(module_name) + self.model_class = getattr(mod, class_name) + last_updated_timestamp = int(time.time()) + result = {'timestamp': last_updated_timestamp} + results = [self.model_class(**result)] + self._iter = iter(results) + else: + raise RuntimeError('model class needs to be provided!') + + def extract(self) -> Any: + """ + Fetch one sql result row, convert to {model_class} if specified before + returning. 
+ :return: + """ + try: + result = next(self._iter) + return result + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.es_last_updated' diff --git a/databuilder/databuilder/extractor/es_metadata_extractor.py b/databuilder/databuilder/extractor/es_metadata_extractor.py new file mode 100644 index 0000000000..2550484b92 --- /dev/null +++ b/databuilder/databuilder/extractor/es_metadata_extractor.py @@ -0,0 +1,79 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 +import json +from typing import ( + Dict, Iterator, Optional, Union, +) + +from databuilder.extractor.es_base_extractor import ElasticsearchBaseExtractor +from databuilder.models.table_metadata import TableMetadata + + +class ElasticsearchMetadataExtractor(ElasticsearchBaseExtractor): + """ + Extractor to extract index metadata from Elasticsearch + """ + + def get_scope(self) -> str: + return 'extractor.es_metadata' + + def _render_programmatic_description(self, input: Optional[Dict]) -> Optional[str]: + if input: + result = f"""```\n{json.dumps(input, indent=2)}\n```""" + + return result + else: + return None + + def _get_extract_iter(self) -> Iterator[Union[TableMetadata, None]]: + indexes: Dict = self._get_indexes() + + for index_name, index_metadata in indexes.items(): + properties = self._get_index_mapping_properties(index_metadata) or dict() + + columns = self._get_attributes(input_mapping=properties) + + # The columns are already sorted, but the sort_order needs to be added to each column metadata entry + if self._correct_sort_order: + for index in range(len(columns)): + columns[index].sort_order = index + + table_metadata = TableMetadata(database=self.database, + cluster=self.cluster, + schema=self.schema, + name=index_name, + description=None, + columns=columns, + is_view=False, + tags=None, + description_source=None) + + yield table_metadata + + if self._extract_technical_details: + _settings = index_metadata.get('settings', dict()) + _aliases = index_metadata.get('aliases', dict()) + + settings = self._render_programmatic_description(_settings) + aliases = self._render_programmatic_description(_aliases) + + if aliases: + yield TableMetadata(database=self.database, + cluster=self.cluster, + schema=self.schema, + name=index_name, + description=aliases, + columns=columns, + is_view=False, + tags=None, + description_source='aliases') + if settings: + yield TableMetadata(database=self.database, + cluster=self.cluster, + schema=self.schema, + name=index_name, + description=settings, + columns=columns, + is_view=False, + tags=None, + description_source='settings') diff --git a/databuilder/databuilder/extractor/es_watermark_extractor.py b/databuilder/databuilder/extractor/es_watermark_extractor.py new file mode 100644 index 0000000000..96007cd4d2 --- /dev/null +++ b/databuilder/databuilder/extractor/es_watermark_extractor.py @@ -0,0 +1,79 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from datetime import datetime +from typing import ( + Dict, Iterator, Optional, Tuple, Union, +) + +from databuilder.extractor.es_base_extractor import ElasticsearchBaseExtractor +from databuilder.models.watermark import Watermark + + +class ElasticsearchWatermarkExtractor(ElasticsearchBaseExtractor): + """ + Extractor to extract index watermarks from Elasticsearch + """ + + def get_scope(self) -> str: + return 'extractor.es_watermark' + + # Internally, Elasticsearch stores dates as numbers representing milliseconds since the epoch, + # so the agg result is expected to be floats. + # See https://www.elastic.co/guide/en/elasticsearch/reference/current/date.html#date + def _get_index_watermark_bounds(self, index_name: str) -> Optional[Tuple[float, float]]: + try: + search_result = self.es.search( + index=index_name, + size=0, + aggs={ + 'min_watermark': {'min': {'field': self._time_field}}, + 'max_watermark': {'max': {'field': self._time_field}} + } + ) + watermark_min = search_result.get('aggregations').get('min_watermark').get('value') + watermark_max = search_result.get('aggregations').get('max_watermark').get('value') + if watermark_min is not None and watermark_max is not None: + return float(watermark_min), float(watermark_max) + except Exception: + pass + + return None + + def _get_extract_iter(self) -> Iterator[Union[Watermark, None]]: + # Get all the indices + indices: Dict = self._get_indexes() + + # Iterate over indices + for index_name, index_metadata in indices.items(): + creation_date: Optional[float] = self._get_index_creation_date(index_metadata) + watermark_bounds: Optional[Tuple[float, float]] = self._get_index_watermark_bounds(index_name=index_name) + watermark_min: Optional[float] = None if watermark_bounds is None else watermark_bounds[0] + watermark_max: Optional[float] = None if watermark_bounds is None else watermark_bounds[1] + + if creation_date is None or watermark_min is None or watermark_max is None: + continue + + creation_date_str: str = datetime.fromtimestamp(creation_date / 1000).strftime('%Y-%m-%d %H:%M:%S') + watermark_min_str: str = datetime.fromtimestamp(watermark_min / 1000).strftime('%Y-%m-%d') + watermark_max_str: str = datetime.fromtimestamp(watermark_max / 1000).strftime('%Y-%m-%d') + + yield Watermark( + database=self.database, + cluster=self.cluster, + schema=self.schema, + table_name=index_name, + create_time=creation_date_str, + part_name=f'{self._time_field}={watermark_min_str}', + part_type='low_watermark' + ) + + yield Watermark( + database=self.database, + cluster=self.cluster, + schema=self.schema, + table_name=index_name, + create_time=creation_date_str, + part_name=f'{self._time_field}={watermark_max_str}', + part_type='high_watermark' + ) diff --git a/databuilder/databuilder/extractor/eventbridge_extractor.py b/databuilder/databuilder/extractor/eventbridge_extractor.py new file mode 100644 index 0000000000..938ec1c3a7 --- /dev/null +++ b/databuilder/databuilder/extractor/eventbridge_extractor.py @@ -0,0 +1,220 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 +import logging +from typing import ( + Any, Dict, Iterator, List, Optional, Union, +) + +import boto3 +import jsonref +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor.base_extractor import Extractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +LOGGER = logging.getLogger(__name__) + + +class EventBridgeExtractor(Extractor): + """ + Extracts the latest version of all schemas from a given AWS EventBridge schema registry + """ + + REGION_NAME_KEY = "region_name" + REGISTRY_NAME_KEY = "registry_name" + DEFAULT_CONFIG = ConfigFactory.from_dict( + {REGION_NAME_KEY: "us-east-1", REGISTRY_NAME_KEY: "aws.events"} + ) + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(EventBridgeExtractor.DEFAULT_CONFIG) + + boto3.setup_default_session( + region_name=conf.get(EventBridgeExtractor.REGION_NAME_KEY) + ) + self._schemas = boto3.client("schemas") + + self._registry_name = conf.get(EventBridgeExtractor.REGISTRY_NAME_KEY) + + self._extract_iter: Union[None, Iterator] = None + + def extract(self) -> Union[TableMetadata, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter(self._registry_name) + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return "extractor.eventbridge" + + def _get_extract_iter(self, registry_name: str) -> Iterator[TableMetadata]: + """ + It gets all the schemas and yields TableMetadata + :return: + """ + for schema_desc in self._get_raw_extract_iter(registry_name): + if "Content" not in schema_desc: + LOGGER.warning( + f"skipped malformatted schema: {jsonref.dumps(schema_desc)}" + ) + continue + + content = jsonref.loads(schema_desc["Content"]) + + if content.get("openapi", "") == "3.0.0": # NOTE: OpenAPI 3.0 + title = content.get("info", {}).get("title", "") + for schema_name, schema in ( + content.get("components", {}).get("schemas", {}).items() + ): + table = EventBridgeExtractor._build_table( + schema, + schema_name, + registry_name, + title, + content.get("description", None), + ) + + if table is None: + continue + + yield table + elif ( + content.get("$schema", "") == "http://json-schema.org/draft-04/schema#" + ): # NOTE: JSON Schema Draft 4 + title = content.get("title", "") + + for schema_name, schema in content.get("definitions", {}).items(): + table = EventBridgeExtractor._build_table( + schema, + schema_name, + registry_name, + title, + schema.get("description", None), + ) + + if table is None: + continue + + yield table + + table = EventBridgeExtractor._build_table( + content, + "Root", + registry_name, + title, + content.get("description", None), + ) + + if table is None: + continue + + yield table + + else: + LOGGER.warning( + f"skipped unsupported schema format: {jsonref.dumps(schema_desc)}" + ) + continue + + def _get_raw_extract_iter(self, registry_name: str) -> Iterator[Dict[str, Any]]: + """ + Provides iterator of results row from schemas client + :return: + """ + schemas_descs = self._search_schemas(registry_name) + return iter(schemas_descs) + + def _search_schemas(self, registry_name: str) -> List[Dict[str, Any]]: + """ + Get all schemas descriptions. 
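+        Pages through list_schemas and list_schema_versions for the registry, then calls
+        describe_schema on the latest version of each schema.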
+ """ + schemas_names = [] + paginator = self._schemas.get_paginator("list_schemas") + for result in paginator.paginate(RegistryName=registry_name): + for schema in result["Schemas"]: + schemas_names.append(schema["SchemaName"]) + + schemas_descs = [] + for schema_name in schemas_names: + schema_versions = [] + paginator = self._schemas.get_paginator("list_schema_versions") + for result in paginator.paginate( + RegistryName=registry_name, SchemaName=schema_name + ): + schema_versions += result["SchemaVersions"] + latest_schema_version = EventBridgeExtractor._get_latest_schema_version( + schema_versions + ) + + schema_desc = self._schemas.describe_schema( + RegistryName=registry_name, + SchemaName=schema_name, + SchemaVersion=latest_schema_version, + ) + + schemas_descs.append(schema_desc) + + return schemas_descs + + @staticmethod + def _build_table( + schema: Dict[str, Any], + schema_name: str, + registry_name: str, + title: str, + description: str, + ) -> Optional[TableMetadata]: + columns = [] + for i, (column_name, properties) in enumerate( + schema.get("properties", {}).items() + ): + columns.append( + ColumnMetadata( + column_name, + properties.get("description", None), + EventBridgeExtractor._get_property_type(properties), + i, + ) + ) + + if len(columns) == 0: + LOGGER.warning( + f"skipped schema with primitive type: " + f"{schema_name}: {jsonref.dumps(schema)}" + ) + return None + + return TableMetadata( + "eventbridge", registry_name, title, schema_name, description, columns, + ) + + @staticmethod + def _get_latest_schema_version(schema_versions: List[Dict[str, Any]]) -> str: + versions = [] + for info in schema_versions: + version = int(info["SchemaVersion"]) + versions.append(version) + return str(max(versions)) + + @staticmethod + def _get_property_type(schema: Dict) -> str: + if "type" not in schema: + return "object" + + if schema["type"] == "object": + properties = [ + f"{name}:{EventBridgeExtractor._get_property_type(_schema)}" + for name, _schema in schema.get("properties", {}).items() + ] + if len(properties) > 0: + return "struct<" + ",".join(properties) + ">" + return "struct" + elif schema["type"] == "array": + items = EventBridgeExtractor._get_property_type(schema.get("items", {})) + return "array<" + items + ">" + else: + if "format" in schema: + return f"{schema['type']}[{schema['format']}]" + return schema["type"] diff --git a/databuilder/databuilder/extractor/feast_extractor.py b/databuilder/databuilder/extractor/feast_extractor.py new file mode 100644 index 0000000000..62a100bd82 --- /dev/null +++ b/databuilder/databuilder/extractor/feast_extractor.py @@ -0,0 +1,137 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from datetime import datetime +from typing import Iterator, Union + +from feast import FeatureStore, FeatureView +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor.base_extractor import Extractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + + +class FeastExtractor(Extractor): + """ + Extracts feature tables from Feast feature store file. 
Since Feast is + a metadata store (and not the database itself), it maps the + following attributes: + + * a database is name of feast project + * table name is a name of the feature view + * columns are features stored in the feature view + """ + + FEAST_REPOSITORY_PATH = "/path/to/repository" + DESCRIBE_FEATURE_VIEWS = "describe_feature_views" + DEFAULT_CONFIG = ConfigFactory.from_dict( + {FEAST_REPOSITORY_PATH: ".", DESCRIBE_FEATURE_VIEWS: True} + ) + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(FeastExtractor.DEFAULT_CONFIG) + self._feast_repository_path = conf.get_string( + FeastExtractor.FEAST_REPOSITORY_PATH + ) + self._describe_feature_views = conf.get_bool( + FeastExtractor.DESCRIBE_FEATURE_VIEWS + ) + self._feast = FeatureStore(repo_path=self._feast_repository_path) + self._extract_iter: Union[None, Iterator] = None + + def get_scope(self) -> str: + return "extractor.feast" + + def extract(self) -> Union[TableMetadata, None]: + """ + For every feature table from Feast, a multiple objets are extracted: + + 1. TableMetadata with feature view description + 2. Programmatic Description of the feature view, containing + metadata - date of creation and labels + 3. Programmatic Description with Batch Source specification + 4. (if applicable) Programmatic Description with Stream Source + specification + """ + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def _get_extract_iter(self) -> Iterator[TableMetadata]: + for feature_view in self._feast.list_feature_views(): + yield from self._extract_feature_view(feature_view) + + def _extract_feature_view( + self, feature_view: FeatureView + ) -> Iterator[TableMetadata]: + columns = [] + for index, entity_name in enumerate(feature_view.entities): + entity = self._feast.get_entity(entity_name) + columns.append( + ColumnMetadata( + entity.name, entity.description, entity.value_type.name, index + ) + ) + + for index, feature in enumerate(feature_view.features): + columns.append( + ColumnMetadata( + feature.name, + None, + feature.dtype.name, + len(feature_view.entities) + index, + ) + ) + + yield TableMetadata( + "feast", + self._feast.config.provider, + self._feast.project, + feature_view.name, + None, + columns, + ) + + if self._describe_feature_views: + description = str() + if feature_view.created_timestamp: + created_at = datetime.utcfromtimestamp( + feature_view.created_timestamp.timestamp() + ) + description = f"* Created at **{created_at}**\n" + + if feature_view.tags: + description += "* Tags:\n" + for key, value in feature_view.tags.items(): + description += f" * {key}: **{value}**\n" + + yield TableMetadata( + "feast", + self._feast.config.provider, + self._feast.project, + feature_view.name, + description, + description_source="feature_view_details", + ) + + yield TableMetadata( + "feast", + self._feast.config.provider, + self._feast.project, + feature_view.name, + f"```\n{str(feature_view.batch_source.to_proto())}```", + description_source="batch_source", + ) + + if feature_view.stream_source: + yield TableMetadata( + "feast", + self._feast.config.provider, + self._feast.project, + feature_view.name, + f"```\n{str(feature_view.stream_source.to_proto())}```", + description_source="stream_source", + ) diff --git a/databuilder/databuilder/extractor/generic_extractor.py b/databuilder/databuilder/extractor/generic_extractor.py new file mode 100644 index 0000000000..979849816e --- /dev/null +++ 
b/databuilder/databuilder/extractor/generic_extractor.py @@ -0,0 +1,52 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import importlib +from typing import Any, Iterable + +from pyhocon import ConfigTree + +from databuilder.extractor.base_extractor import Extractor + + +class GenericExtractor(Extractor): + """ + Extractor to extract any arbitrary values from users. + """ + EXTRACTION_ITEMS = 'extraction_items' + + def init(self, conf: ConfigTree) -> None: + """ + Receives a list of dictionaries which is used for extraction + :param conf: + :return: + """ + self.conf = conf + self.values: Iterable[Any] = conf.get(GenericExtractor.EXTRACTION_ITEMS) + + model_class = conf.get('model_class', None) + if model_class: + module_name, class_name = model_class.rsplit(".", 1) + mod = importlib.import_module(module_name) + self.model_class = getattr(mod, class_name) + results = [self.model_class(**result) + for result in self.values] + + self._iter = iter(results) + else: + self._iter = iter(self.values) + + def extract(self) -> Any: + """ + Fetch one sql result row, convert to {model_class} if specified before + returning. + :return: + """ + try: + result = next(self._iter) + return result + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.generic' diff --git a/databuilder/databuilder/extractor/generic_usage_extractor.py b/databuilder/databuilder/extractor/generic_usage_extractor.py new file mode 100644 index 0000000000..44ef15071a --- /dev/null +++ b/databuilder/databuilder/extractor/generic_usage_extractor.py @@ -0,0 +1,113 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +""" +The Generic Usage Extractor allows you to populate the "Frequent Users" and "Popular Tables" features within +Amundsen with the help of the TableColumnUsage class. Because this is a generic usage extractor, you need to create +a custom log parser to calculate poplarity. There is an example of how to calculate table popularity for Snowflake +in databuilder/example/scripts/sample_snowflake_table_usage.scala. 
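+
+The source table is expected to expose the columns selected below: database, schema, name,
+user_email and read_count. A minimal, illustrative job configuration (key names come from the
+constants in this module; the table, schema and database values are placeholders to replace):
+
+    ConfigFactory.from_dict({
+        'extractor.generic_usage.popularity_table_database': 'PROD',
+        'extractor.generic_usage.popularity_table_schema': 'UTIL',
+        'extractor.generic_usage.popularity_table_name': 'TABLE_POPULARITY',
+        'extractor.generic_usage.database_key': 'snowflake',
+        'extractor.generic_usage.where_clause_suffix': ' ',
+    })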
+""" + +import logging +from typing import ( + Any, Dict, Iterator, Union, +) + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder import Scoped +from databuilder.extractor import sql_alchemy_extractor +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.models.table_column_usage import ColumnReader, TableColumnUsage + +LOGGER = logging.getLogger(__name__) + + +class GenericUsageExtractor(Extractor): + # SELECT statement from table that contains usage data + SQL_STATEMENT = """ + SELECT + database, + schema, + name, + user_email, + read_count + FROM + {database}.{schema}.{table} + {where_clause_suffix}; + """ + # CONFIG KEYS + WHERE_CLAUSE_SUFFIX_KEY = 'where_clause_suffix' + POPULARITY_TABLE_DATABASE = 'popularity_table_database' + POPULARTIY_TABLE_SCHEMA = 'popularity_table_schema' + POPULARITY_TABLE_NAME = 'popularity_table_name' + DATABASE_KEY = 'database_key' + + DEFAULT_CONFIG = ConfigFactory.from_dict( + {WHERE_CLAUSE_SUFFIX_KEY: ' ', + POPULARITY_TABLE_DATABASE: 'PROD', + POPULARTIY_TABLE_SCHEMA: 'SCHEMA', + POPULARITY_TABLE_NAME: 'TABLE', + DATABASE_KEY: 'snowflake'} + ) + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(GenericUsageExtractor.DEFAULT_CONFIG) + self._where_clause_suffix = conf.get_string(GenericUsageExtractor.WHERE_CLAUSE_SUFFIX_KEY) + self._popularity_table_database = conf.get_string(GenericUsageExtractor.POPULARITY_TABLE_DATABASE) + self._popularity_table_schema = conf.get_string(GenericUsageExtractor.POPULARTIY_TABLE_SCHEMA) + self._popularity_table_name = conf.get_string(GenericUsageExtractor.POPULARITY_TABLE_NAME) + self._database_key = conf.get_string(GenericUsageExtractor.DATABASE_KEY) + + self.sql_stmt = self.SQL_STATEMENT.format( + where_clause_suffix=self._where_clause_suffix, + database=self._popularity_table_database, + schema=self._popularity_table_schema, + table=self._popularity_table_name + ) + + LOGGER.info("SQL for popularity: {}".format(self.sql_stmt)) + + self._alchemy_extractor = sql_alchemy_extractor.from_surrounding_config(conf, self.sql_stmt) + sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope()) \ + .with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt})) + self._alchemy_extractor.init(sql_alch_conf) + self._extract_iter: Union[None, Iterator] = None + + def extract(self) -> Union[TableColumnUsage, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return "extractor.generic_usage" + + def _get_extract_iter(self) -> Iterator[TableColumnUsage]: + """ + Using raw level iterator, it groups to table and yields TableColumnUsage + :return: + """ + for row in self._get_raw_extract_iter(): + col_readers = [] + col_readers.append(ColumnReader(database=self._database_key, + cluster=row["database"], + schema=row["schema"], + table=row["name"], + column="*", + user_email=row["user_email"], + read_count=row["read_count"])) + yield TableColumnUsage(col_readers=col_readers) + + def _get_raw_extract_iter(self) -> Iterator[Dict[str, Any]]: + """ + Provides iterator of result row from SQLAlchemy extractor + :return: + """ + row = self._alchemy_extractor.extract() + while row: + yield row + row = self._alchemy_extractor.extract() diff --git a/databuilder/databuilder/extractor/glue_extractor.py 
b/databuilder/databuilder/extractor/glue_extractor.py new file mode 100644 index 0000000000..fc9ab8b865 --- /dev/null +++ b/databuilder/databuilder/extractor/glue_extractor.py @@ -0,0 +1,124 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + Any, Dict, Iterator, List, Union, +) + +import boto3 +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor.base_extractor import Extractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + + +class GlueExtractor(Extractor): + """ + Extracts tables and columns metadata from AWS Glue metastore + """ + + CLUSTER_KEY = 'cluster' + FILTER_KEY = 'filters' + MAX_RESULTS_KEY = 'max_results' + RESOURCE_SHARE_TYPE = 'resource_share_type' + REGION_NAME_KEY = "region" + PARTITION_BADGE_LABEL_KEY = "partition_badge_label" + + DEFAULT_CONFIG = ConfigFactory.from_dict({ + CLUSTER_KEY: 'gold', + FILTER_KEY: None, + MAX_RESULTS_KEY: 500, + RESOURCE_SHARE_TYPE: "ALL", + REGION_NAME_KEY: None, + PARTITION_BADGE_LABEL_KEY: None, + }) + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(GlueExtractor.DEFAULT_CONFIG) + self._cluster = conf.get_string(GlueExtractor.CLUSTER_KEY) + self._filters = conf.get(GlueExtractor.FILTER_KEY) + self._max_results = conf.get(GlueExtractor.MAX_RESULTS_KEY) + self._resource_share_type = conf.get(GlueExtractor.RESOURCE_SHARE_TYPE) + self._region_name = conf.get(GlueExtractor.REGION_NAME_KEY) + self._partition_badge_label = conf.get(GlueExtractor.PARTITION_BADGE_LABEL_KEY) + if self._region_name is not None: + self._glue = boto3.client('glue', region_name=self._region_name) + else: + self._glue = boto3.client('glue') + self._extract_iter: Union[None, Iterator] = None + + def extract(self) -> Union[TableMetadata, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.glue' + + def _get_extract_iter(self) -> Iterator[TableMetadata]: + """ + It gets all tables and yields TableMetadata + :return: + """ + for row in self._get_raw_extract_iter(): + columns, i = [], 0 + + if 'StorageDescriptor' not in row: + continue + + for column in row['StorageDescriptor']['Columns']: + columns.append(ColumnMetadata( + name=column["Name"], + description=column.get("Comment"), + col_type=column["Type"], + sort_order=i, + )) + i += 1 + + for column in row.get('PartitionKeys', []): + columns.append(ColumnMetadata( + name=column["Name"], + description=column.get("Comment"), + col_type=column["Type"], + sort_order=i, + badges=[self._partition_badge_label] if self._partition_badge_label else None, + )) + i += 1 + + yield TableMetadata( + 'glue', + self._cluster, + row['DatabaseName'], + row['Name'], + row.get('Description') or row.get('Parameters', {}).get('comment'), + columns, + row.get('TableType') == 'VIRTUAL_VIEW', + ) + + def _get_raw_extract_iter(self) -> Iterator[Dict[str, Any]]: + """ + Provides iterator of results row from glue client + :return: + """ + tables = self._search_tables() + return iter(tables) + + def _search_tables(self) -> List[Dict[str, Any]]: + tables = [] + kwargs = {} + if self._filters is not None: + kwargs['Filters'] = self._filters + kwargs['MaxResults'] = self._max_results + if self._resource_share_type: + kwargs['ResourceShareType'] = self._resource_share_type + data = self._glue.search_tables(**kwargs) + tables += data['TableList'] + while 
'NextToken' in data: + token = data['NextToken'] + kwargs['NextToken'] = token + data = self._glue.search_tables(**kwargs) + tables += data['TableList'] + return tables diff --git a/databuilder/databuilder/extractor/hive_table_last_updated_extractor.py b/databuilder/databuilder/extractor/hive_table_last_updated_extractor.py new file mode 100644 index 0000000000..4c718ded13 --- /dev/null +++ b/databuilder/databuilder/extractor/hive_table_last_updated_extractor.py @@ -0,0 +1,370 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import time +from datetime import datetime +from functools import wraps +from multiprocessing.pool import ThreadPool +from typing import ( + Any, Iterator, List, Union, +) + +from pyhocon import ConfigFactory, ConfigTree +from pytz import UTC + +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.filesystem.filesystem import FileSystem, is_client_side_error +from databuilder.models.table_last_updated import TableLastUpdated + +LOGGER = logging.getLogger(__name__) +OLDEST_TIMESTAMP = datetime.fromtimestamp(0, UTC) + + +def fs_error_handler(f: Any) -> Any: + """ + A Decorator that handles error from FileSystem for HiveTableLastUpdatedExtractor use case + If it's client side error, it logs in INFO level, and other errors is logged as error level with stacktrace. + The decorator is intentionally not re-raising exception so that it can isolate the error. + :param f: + :return: + """ + + @wraps(f) + def wrapper(*args: Any, **kwargs: Any) -> Any: + try: + return f(*args, **kwargs) + except Exception as e: + if is_client_side_error(e): + LOGGER.info('Invalid metadata. Skipping. args: %s, kwargs: %s. error: %s', args, kwargs, e) + return None + else: + LOGGER.exception('Unknown exception while processing args: %s, kwargs: %s', args, kwargs) + return None + + return wrapper + + +class HiveTableLastUpdatedExtractor(Extractor): + """ + Uses Hive metastore and underlying storage to figure out last updated timestamp of table. + + It turned out that we cannot use table param "last_modified_time", as it provides DDL date not DML date. + For this reason, we are utilizing underlying file of Hive to fetch latest updated date. + However, it is not efficient to poke all files in Hive, and it only poke underlying storage for non-partitioned + table. For partitioned table, it will fetch partition created timestamp, and it's close enough for last updated + timestamp. 
+ + """ + DEFAULT_PARTITION_TABLE_SQL_STATEMENT = """ + SELECT + DBS.NAME as `schema`, + TBL_NAME as table_name, + MAX(PARTITIONS.CREATE_TIME) as last_updated_time + FROM TBLS + JOIN DBS ON TBLS.DB_ID = DBS.DB_ID + JOIN PARTITIONS ON TBLS.TBL_ID = PARTITIONS.TBL_ID + {where_clause_suffix} + GROUP BY `schema`, table_name + ORDER BY `schema`, table_name; + """ + + DEFAULT_POSTGRES_PARTITION_TABLE_SQL_STATEMENT = """ + SELECT + d."NAME" as "schema", + t."TBL_NAME" as table_name, + MAX(p."CREATE_TIME") as last_updated_time + FROM "TBLS" t + JOIN "DBS" d ON t."DB_ID" = d."DB_ID" + JOIN "PARTITIONS" p ON t."TBL_ID" = p."TBL_ID" + {where_clause_suffix} + GROUP BY "schema", table_name + ORDER BY "schema", table_name; + """ + + DEFAULT_NON_PARTITIONED_TABLE_SQL_STATEMENT = """ + SELECT + DBS.NAME as `schema`, + TBL_NAME as table_name, + SDS.LOCATION as location + FROM TBLS + JOIN DBS ON TBLS.DB_ID = DBS.DB_ID + JOIN SDS ON TBLS.SD_ID = SDS.SD_ID + {where_clause_suffix} + ORDER BY `schema`, table_name; + """ + + DEFAULT_POSTGRES_NON_PARTITIONED_TABLE_SQL_STATEMENT = """ + SELECT + d."NAME" as "schema", + t."TBL_NAME" as table_name, + s."LOCATION" as location + FROM "TBLS" t + JOIN "DBS" d ON t."DB_ID" = d."DB_ID" + JOIN "SDS" s ON t."SD_ID" = s."SD_ID" + {where_clause_suffix} + ORDER BY "schema", table_name; + """ + + # Additional where clause for non partitioned table SQL + DEFAULT_ADDTIONAL_WHERE_CLAUSE = """ NOT EXISTS (SELECT * FROM PARTITIONS WHERE PARTITIONS.TBL_ID = TBLS.TBL_ID) + AND NOT EXISTS (SELECT * FROM PARTITION_KEYS WHERE PARTITION_KEYS.TBL_ID = TBLS.TBL_ID) + """ + + DEFAULT_POSTGRES_ADDTIONAL_WHERE_CLAUSE = """ NOT EXISTS (SELECT * FROM "PARTITIONS" p + WHERE p."TBL_ID" = t."TBL_ID") AND NOT EXISTS (SELECT * FROM "PARTITION_KEYS" pk WHERE pk."TBL_ID" = t."TBL_ID") + """ + + DATABASE = 'hive' + + # CONFIG KEYS + PARTITIONED_TABLE_WHERE_CLAUSE_SUFFIX_KEY = 'partitioned_table_where_clause_suffix' + NON_PARTITIONED_TABLE_WHERE_CLAUSE_SUFFIX_KEY = 'non_partitioned_table_where_clause_suffix' + CLUSTER_KEY = 'cluster' + # number of threads that fetches metadata from FileSystem + FS_WORKER_POOL_SIZE = 'fs_worker_pool_size' + FS_WORKER_TIMEOUT_SEC = 'fs_worker_timeout_sec' + # If number of files that it needs to fetch metadata is larger than this threshold, it will skip the table. 
+ FILE_CHECK_THRESHOLD = 'file_check_threshold' + + DEFAULT_CONFIG = ConfigFactory.from_dict({PARTITIONED_TABLE_WHERE_CLAUSE_SUFFIX_KEY: ' ', + NON_PARTITIONED_TABLE_WHERE_CLAUSE_SUFFIX_KEY: ' ', + CLUSTER_KEY: 'gold', + FS_WORKER_POOL_SIZE: 500, + FS_WORKER_TIMEOUT_SEC: 60, + FILE_CHECK_THRESHOLD: -1}) + + def init(self, conf: ConfigTree) -> None: + self._conf = conf.with_fallback(HiveTableLastUpdatedExtractor.DEFAULT_CONFIG) + + pool_size = self._conf.get_int(HiveTableLastUpdatedExtractor.FS_WORKER_POOL_SIZE) + LOGGER.info('Using thread pool size: %s', pool_size) + self._fs_worker_pool = ThreadPool(processes=pool_size) + self._fs_worker_timeout = self._conf.get_int(HiveTableLastUpdatedExtractor.FS_WORKER_TIMEOUT_SEC) + LOGGER.info('Using thread timeout: %s seconds', self._fs_worker_timeout) + + self._cluster = self._conf.get_string(HiveTableLastUpdatedExtractor.CLUSTER_KEY) + + self._partitioned_table_extractor = self._get_partitioned_table_sql_alchemy_extractor() + self._non_partitioned_table_extractor = self._get_non_partitioned_table_sql_alchemy_extractor() + self._fs = self._get_filesystem() + self._last_updated_filecheck_threshold \ + = self._conf.get_int(HiveTableLastUpdatedExtractor.FILE_CHECK_THRESHOLD) + + self._extract_iter: Union[None, Iterator] = None + + def _get_partitioned_table_sql_alchemy_extractor(self) -> Extractor: + """ + Getting an SQLAlchemy extractor that extracts last updated timestamp for partitioned table. + :return: SQLAlchemyExtractor + """ + + sql_stmt = self._choose_default_partitioned_sql_stm().format( + where_clause_suffix=self._conf.get_string( + self.PARTITIONED_TABLE_WHERE_CLAUSE_SUFFIX_KEY, ' ')) + + LOGGER.info('SQL for partitioned table against Hive metastore: %s', sql_stmt) + + sql_alchemy_extractor = SQLAlchemyExtractor() + sql_alchemy_conf = Scoped.get_scoped_conf(self._conf, sql_alchemy_extractor.get_scope()) \ + .with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: sql_stmt})) + sql_alchemy_extractor.init(sql_alchemy_conf) + return sql_alchemy_extractor + + def _choose_default_partitioned_sql_stm(self) -> str: + conn_string = self._conf.get_string("extractor.sqlalchemy.conn_string") + if conn_string.startswith('postgres') or conn_string.startswith('postgresql'): + return self.DEFAULT_POSTGRES_PARTITION_TABLE_SQL_STATEMENT + else: + return self.DEFAULT_PARTITION_TABLE_SQL_STATEMENT + + def _get_non_partitioned_table_sql_alchemy_extractor(self) -> Extractor: + """ + Getting an SQLAlchemy extractor that extracts storage location for non-partitioned table for further probing + last updated timestamp + + :return: SQLAlchemyExtractor + """ + + default_sql_stmt, default_additional_where_clause = self._choose_default_non_partitioned_sql_stm() + + if self.NON_PARTITIONED_TABLE_WHERE_CLAUSE_SUFFIX_KEY in self._conf: + where_clause_suffix = """ + {} + AND {} + """.format(self._conf.get_string( + self.NON_PARTITIONED_TABLE_WHERE_CLAUSE_SUFFIX_KEY), + default_additional_where_clause) + else: + where_clause_suffix = 'WHERE {}'.format(default_additional_where_clause) + + sql_stmt = default_sql_stmt.format( + where_clause_suffix=where_clause_suffix) + + LOGGER.info('SQL for non-partitioned table against Hive metastore: %s', sql_stmt) + + sql_alchemy_extractor = SQLAlchemyExtractor() + sql_alchemy_conf = Scoped.get_scoped_conf(self._conf, sql_alchemy_extractor.get_scope()) \ + .with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: sql_stmt})) + sql_alchemy_extractor.init(sql_alchemy_conf) + return sql_alchemy_extractor + + 
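+    # Note on configuration (a sketch, assuming the standard databuilder scoped-config setup): at the
+    # job level the options above are read through this extractor's scope, e.g.
+    #   'extractor.hive_table_last_updated.partitioned_table_where_clause_suffix'
+    #   'extractor.hive_table_last_updated.non_partitioned_table_where_clause_suffix'
+    #   'extractor.hive_table_last_updated.extractor.sqlalchemy.conn_string'
+    # plus whatever filesystem configuration _get_filesystem() below expects.
+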
def _choose_default_non_partitioned_sql_stm(self) -> List[str]: + conn_string = self._conf.get_string("extractor.sqlalchemy.conn_string") + if conn_string.startswith('postgres') or conn_string.startswith('postgresql'): + return [self.DEFAULT_POSTGRES_NON_PARTITIONED_TABLE_SQL_STATEMENT, + self.DEFAULT_POSTGRES_ADDTIONAL_WHERE_CLAUSE] + else: + return [self.DEFAULT_NON_PARTITIONED_TABLE_SQL_STATEMENT, self.DEFAULT_ADDTIONAL_WHERE_CLAUSE] + + def _get_filesystem(self) -> FileSystem: + fs = FileSystem() + fs.init(Scoped.get_scoped_conf(self._conf, fs.get_scope())) + return fs + + def extract(self) -> Union[TableLastUpdated, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.hive_table_last_updated' + + def _get_extract_iter(self) -> Iterator[TableLastUpdated]: + """ + An iterator that utilizes Generator pattern. First it provides TableLastUpdated objects for partitioned + table, straight from partitioned_table_extractor (SQLAlchemyExtractor) + + Once partitioned table is done, it uses non_partitioned_table_extractor to get storage location of table, + and probing files under storage location to get max timestamp per table. + :return: + """ + + partitioned_tbl_row = self._partitioned_table_extractor.extract() + while partitioned_tbl_row: + yield TableLastUpdated(table_name=partitioned_tbl_row['table_name'], + last_updated_time_epoch=partitioned_tbl_row['last_updated_time'], + schema=partitioned_tbl_row['schema'], + db=HiveTableLastUpdatedExtractor.DATABASE, + cluster=self._cluster) + partitioned_tbl_row = self._partitioned_table_extractor.extract() + + LOGGER.info('Extracting non-partitioned table') + count = 0 + non_partitioned_tbl_row = self._non_partitioned_table_extractor.extract() + while non_partitioned_tbl_row: + count += 1 + if count % 10 == 0: + LOGGER.info('Processed %i non-partitioned tables', count) + + if not non_partitioned_tbl_row['location']: + LOGGER.warning('Skipping as no storage location available. %s', non_partitioned_tbl_row) + non_partitioned_tbl_row = self._non_partitioned_table_extractor.extract() + continue + + start = time.time() + table_last_updated = self._get_last_updated_datetime_from_filesystem( + table=non_partitioned_tbl_row['table_name'], + schema=non_partitioned_tbl_row['schema'], + storage_location=non_partitioned_tbl_row['location']) + LOGGER.info('Elapsed: %i seconds', time.time() - start) + + if table_last_updated: + yield table_last_updated + + non_partitioned_tbl_row = self._non_partitioned_table_extractor.extract() + + def _get_last_updated_datetime_from_filesystem(self, + table: str, + schema: str, + storage_location: str, + ) -> Union[TableLastUpdated, None]: + """ + Fetching metadata within files under storage_location to get latest timestamp. + (First level only under storage_location) + Utilizes thread pool to enhance performance. Not using processpool, as it's almost entirely IO bound operation. + + :param table: + :param schema: + :param storage_location: + :return: + """ + + if LOGGER.isEnabledFor(logging.DEBUG): + LOGGER.debug(f'Getting last updated datetime for {schema}.{table} in {storage_location}') + + last_updated = OLDEST_TIMESTAMP + + paths = self._ls(storage_location) + if not paths: + LOGGER.info(f'{schema}.{table} does not have any file in path {storage_location}. 
Skipping') + return None + + LOGGER.info(f'Fetching metadata for {schema}.{table} of {len(paths)} files') + + if 0 < self._last_updated_filecheck_threshold < len(paths): + LOGGER.info(f'Skipping {schema}.{table} due to too many files. ' + f'{len(paths)} files exist in {storage_location}') + return None + + time_stamp_futures = \ + [self._fs_worker_pool.apply_async(self._get_timestamp, (path, schema, table, storage_location)) + for path in paths] + for time_stamp_future in time_stamp_futures: + try: + time_stamp = time_stamp_future.get(timeout=self._fs_worker_timeout) + if time_stamp: + last_updated = max(time_stamp, last_updated) + except TimeoutError: + LOGGER.warning('Timed out on paths %s . Skipping', paths) + + if last_updated == OLDEST_TIMESTAMP: + LOGGER.info(f'No timestamp was derived on {schema}.{table} from location: {storage_location} . Skipping') + return None + + result = TableLastUpdated(table_name=table, + last_updated_time_epoch=int((last_updated - OLDEST_TIMESTAMP).total_seconds()), + schema=schema, + db=HiveTableLastUpdatedExtractor.DATABASE, + cluster=self._cluster) + + return result + + @fs_error_handler + def _ls(self, path: str) -> List[str]: + """ + An wrapper to FileSystem.ls to use fs_error_handler decorator + :param path: + :return: + """ + return self._fs.ls(path) + + @fs_error_handler + def _get_timestamp(self, + path: str, + schema: str, + table: str, + storage_location: str, + ) -> Union[datetime, None]: + """ + An wrapper to FileSystem.ls to use fs_error_handler decorator + :param path: + :param schema: + :param table: + :param storage_location: + :return: + """ + if not path: + LOGGER.info(f'Empty path {path} on {schema}.{table} in storage location {storage_location} . Skipping') + return None + + if not self._fs.is_file(path): + return None + + file_metadata = self._fs.info(path) + return file_metadata.last_updated diff --git a/databuilder/databuilder/extractor/hive_table_metadata_extractor.py b/databuilder/databuilder/extractor/hive_table_metadata_extractor.py new file mode 100644 index 0000000000..629e39f61a --- /dev/null +++ b/databuilder/databuilder/extractor/hive_table_metadata_extractor.py @@ -0,0 +1,179 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections import namedtuple +from itertools import groupby +from typing import ( + Any, Dict, Iterator, Union, +) + +from pyhocon import ConfigFactory, ConfigTree +from sqlalchemy.engine.url import make_url + +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.extractor.table_metadata_constants import PARTITION_BADGE +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +TableKey = namedtuple('TableKey', ['schema', 'table_name']) + +LOGGER = logging.getLogger(__name__) + + +class HiveTableMetadataExtractor(Extractor): + """ + Extracts Hive table and column metadata from underlying meta store database using SQLAlchemyExtractor + """ + EXTRACT_SQL = 'extract_sql' + # SELECT statement from hive metastore database to extract table and column metadata + # Below SELECT statement uses UNION to combining two queries together. + # 1st query is retrieving partition columns + # 2nd query is retrieving columns + # Using UNION to combine above two statements and order by table & partition identifier. 
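+    # The same {where_clause_suffix} placeholder is spliced into both halves of the UNION; for example
+    # (illustrative), a suffix such as "WHERE d.NAME IN ('core')" limits extraction to one schema.
+    # The whole statement can also be replaced via the 'extract_sql' config key handled in init().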
+ DEFAULT_SQL_STATEMENT = """ + SELECT source.* FROM + (SELECT t.TBL_ID, d.NAME as `schema`, t.TBL_NAME name, t.TBL_TYPE, tp.PARAM_VALUE as description, + p.PKEY_NAME as col_name, p.INTEGER_IDX as col_sort_order, + p.PKEY_TYPE as col_type, p.PKEY_COMMENT as col_description, 1 as "is_partition_col", + IF(t.TBL_TYPE = 'VIRTUAL_VIEW', 1, 0) "is_view" + FROM TBLS t + JOIN DBS d ON t.DB_ID = d.DB_ID + JOIN PARTITION_KEYS p ON t.TBL_ID = p.TBL_ID + LEFT JOIN TABLE_PARAMS tp ON (t.TBL_ID = tp.TBL_ID AND tp.PARAM_KEY='comment') + {where_clause_suffix} + UNION + SELECT t.TBL_ID, d.NAME as `schema`, t.TBL_NAME name, t.TBL_TYPE, tp.PARAM_VALUE as description, + c.COLUMN_NAME as col_name, c.INTEGER_IDX as col_sort_order, + c.TYPE_NAME as col_type, c.COMMENT as col_description, 0 as "is_partition_col", + IF(t.TBL_TYPE = 'VIRTUAL_VIEW', 1, 0) "is_view" + FROM TBLS t + JOIN DBS d ON t.DB_ID = d.DB_ID + JOIN SDS s ON t.SD_ID = s.SD_ID + JOIN COLUMNS_V2 c ON s.CD_ID = c.CD_ID + LEFT JOIN TABLE_PARAMS tp ON (t.TBL_ID = tp.TBL_ID AND tp.PARAM_KEY='comment') + {where_clause_suffix} + ) source + ORDER by tbl_id, is_partition_col desc; + """ + + DEFAULT_POSTGRES_SQL_STATEMENT = """ + SELECT source.* FROM + (SELECT t."TBL_ID" as tbl_id, d."NAME" as "schema", t."TBL_NAME" as name, t."TBL_TYPE", + tp."PARAM_VALUE" as description, p."PKEY_NAME" as col_name, p."INTEGER_IDX" as col_sort_order, + p."PKEY_TYPE" as col_type, p."PKEY_COMMENT" as col_description, 1 as "is_partition_col", + CASE WHEN t."TBL_TYPE" = 'VIRTUAL_VIEW' THEN 1 + ELSE 0 END as "is_view" + FROM "TBLS" t + JOIN "DBS" d ON t."DB_ID" = d."DB_ID" + JOIN "PARTITION_KEYS" p ON t."TBL_ID" = p."TBL_ID" + LEFT JOIN "TABLE_PARAMS" tp ON (t."TBL_ID" = tp."TBL_ID" AND tp."PARAM_KEY"='comment') + {where_clause_suffix} + UNION + SELECT t."TBL_ID" as tbl_id, d."NAME" as "schema", t."TBL_NAME" as name, t."TBL_TYPE", + tp."PARAM_VALUE" as description, c."COLUMN_NAME" as col_name, c."INTEGER_IDX" as col_sort_order, + c."TYPE_NAME" as col_type, c."COMMENT" as col_description, 0 as "is_partition_col", + CASE WHEN t."TBL_TYPE" = 'VIRTUAL_VIEW' THEN 1 + ELSE 0 END as "is_view" + FROM "TBLS" t + JOIN "DBS" d ON t."DB_ID" = d."DB_ID" + JOIN "SDS" s ON t."SD_ID" = s."SD_ID" + JOIN "COLUMNS_V2" c ON s."CD_ID" = c."CD_ID" + LEFT JOIN "TABLE_PARAMS" tp ON (t."TBL_ID" = tp."TBL_ID" AND tp."PARAM_KEY"='comment') + {where_clause_suffix} + ) source + ORDER by tbl_id, is_partition_col desc; + """ + + # CONFIG KEYS + WHERE_CLAUSE_SUFFIX_KEY = 'where_clause_suffix' + CLUSTER_KEY = 'cluster' + + DEFAULT_CONFIG = ConfigFactory.from_dict({WHERE_CLAUSE_SUFFIX_KEY: ' ', + CLUSTER_KEY: 'gold'}) + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(HiveTableMetadataExtractor.DEFAULT_CONFIG) + self._cluster = conf.get_string(HiveTableMetadataExtractor.CLUSTER_KEY) + + self._alchemy_extractor = SQLAlchemyExtractor() + + sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope()) + default_sql = self._choose_default_sql_stm(sql_alch_conf).format( + where_clause_suffix=conf.get_string(HiveTableMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY)) + + self.sql_stmt = conf.get_string(HiveTableMetadataExtractor.EXTRACT_SQL, default=default_sql) + + LOGGER.info('SQL for hive metastore: %s', self.sql_stmt) + + sql_alch_conf = sql_alch_conf.with_fallback(ConfigFactory.from_dict( + {SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt})) + self._alchemy_extractor.init(sql_alch_conf) + self._extract_iter: Union[None, Iterator] = None + + def _choose_default_sql_stm(self, 
conf: ConfigTree) -> str: + url = make_url(conf.get_string(SQLAlchemyExtractor.CONN_STRING)) + if url.drivername.lower() in ['postgresql', 'postgres']: + return self.DEFAULT_POSTGRES_SQL_STATEMENT + else: + return self.DEFAULT_SQL_STATEMENT + + def extract(self) -> Union[TableMetadata, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.hive_table_metadata' + + def _get_extract_iter(self) -> Iterator[TableMetadata]: + """ + Using itertools.groupby and raw level iterator, it groups to table and yields TableMetadata + :return: + """ + for key, group in groupby(self._get_raw_extract_iter(), self._get_table_key): + columns = [] + + for row in group: + last_row = row + column = None + if row['is_partition_col'] == 1: + # create add a badge to indicate partition column + column = ColumnMetadata(row['col_name'], row['col_description'], + row['col_type'], row['col_sort_order'], [PARTITION_BADGE]) + else: + column = ColumnMetadata(row['col_name'], row['col_description'], + row['col_type'], row['col_sort_order']) + columns.append(column) + is_view = last_row['is_view'] == 1 + yield TableMetadata('hive', self._cluster, + last_row['schema'], + last_row['name'], + last_row['description'], + columns, + is_view=is_view) + + def _get_raw_extract_iter(self) -> Iterator[Dict[str, Any]]: + """ + Provides iterator of result row from SQLAlchemy extractor + :return: + """ + row = self._alchemy_extractor.extract() + while row: + yield row + row = self._alchemy_extractor.extract() + + def _get_table_key(self, row: Dict[str, Any]) -> Union[TableKey, None]: + """ + Table key consists of schema and table name + :param row: + :return: + """ + if row: + return TableKey(schema=row['schema'], table_name=row['name']) + + return None diff --git a/databuilder/databuilder/extractor/kafka_schema_registry_extractor.py b/databuilder/databuilder/extractor/kafka_schema_registry_extractor.py new file mode 100644 index 0000000000..b6d29aee8b --- /dev/null +++ b/databuilder/databuilder/extractor/kafka_schema_registry_extractor.py @@ -0,0 +1,183 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0
+import logging
+from typing import (
+    Any, Dict, Iterator, List, Optional, Union,
+)
+
+from pyhocon import ConfigTree
+from schema_registry.client import Auth, SchemaRegistryClient
+from schema_registry.client.utils import SchemaVersion
+
+from databuilder.extractor.base_extractor import Extractor
+from databuilder.models.table_metadata import ColumnMetadata, TableMetadata
+
+LOGGER = logging.getLogger(__name__)
+
+
+class KafkaSchemaRegistryExtractor(Extractor):
+    """
+    Extracts the latest version of all schemas from a given
+    Kafka Schema Registry URL
+    """
+
+    REGISTRY_URL_KEY = "registry_url"
+    REGISTRY_USERNAME_KEY = "registry_username"
+    REGISTRY_PASSWORD_KEY = "registry_password"
+
+    def init(self, conf: ConfigTree) -> None:
+        self._registry_base_url = conf.get(
+            KafkaSchemaRegistryExtractor.REGISTRY_URL_KEY
+        )
+
+        self._registry_username = conf.get(
+            KafkaSchemaRegistryExtractor.REGISTRY_USERNAME_KEY, None
+        )
+
+        self._registry_password = conf.get(
+            KafkaSchemaRegistryExtractor.REGISTRY_PASSWORD_KEY, None
+        )
+
+        # Add authentication only if both username and password are provided
+        if all((self._registry_username, self._registry_password)):
+            self._client = SchemaRegistryClient(
+                url=self._registry_base_url,
+                auth=Auth(
+                    username=self._registry_username,
+                    password=self._registry_password
+                )
+            )
+        else:
+            self._client = SchemaRegistryClient(
+                url=self._registry_base_url,
+            )
+
+        self._extract_iter: Union[None, Iterator] = None
+
+    def extract(self) -> Union[TableMetadata, None]:
+        if not self._extract_iter:
+            self._extract_iter = self._get_extract_iter()
+        if self._extract_iter is None:
+            return None
+        try:
+            return next(self._extract_iter)
+        except StopIteration:
+            return None
+        except Exception as e:
+            LOGGER.error(f'Failed to generate next table: {e}')
+            return None
+
+    def get_scope(self) -> str:
+        return 'extractor.kafka_schema_registry'
+
+    def _get_extract_iter(self) -> Optional[Iterator[TableMetadata]]:
+        """
+        Return an iterator generating TableMetadata for all of the schemas.
+        """
+        for schema_version in self._get_raw_extract_iter():
+            subject = schema_version.subject
+            schema = schema_version.schema.raw_schema
+            LOGGER.info((f'Subject: {subject}, '
+                         f'Schema: {schema}'))
+
+            try:
+                yield KafkaSchemaRegistryExtractor._create_table(
+                    schema=schema,
+                    subject_name=subject,
+                    cluster_name=schema.get(
+                        'namespace', 'kafka-schema-registry'
+                    ),
+                    schema_name=schema.get('name', ''),
+                    schema_description=schema.get('doc', None),
+                )
+            except Exception as e:
+                LOGGER.warning(f'Failed to generate table for {subject}: {e}')
+                continue
+
+    def _get_raw_extract_iter(self) -> Iterator[SchemaVersion]:
+        """
+        Return an iterator of the latest schema version for each subject in the schema registry
+        """
+        subjects = self._client.get_subjects()
+
+        LOGGER.info(f'Number of extracted subjects: {len(subjects)}')
+        LOGGER.info(f'Extracted subjects: {subjects}')
+
+        for subj in subjects:
+            subj_schema = self._client.get_schema(subj)
+            LOGGER.info(f'Subject <{subj}> max version: {subj_schema.version}')
+
+            yield subj_schema
+
+    @staticmethod
+    def _create_table(
+        schema: Dict[str, Any],
+        subject_name: str,
+        cluster_name: str,
+        schema_name: str,
+        schema_description: str,
+    ) -> Optional[TableMetadata]:
+        """
+        Create TableMetadata based on the given schema and names
+        """
+        columns: List[ColumnMetadata] = []
+
+        for i, field in enumerate(schema['fields']):
+            columns.append(
+                ColumnMetadata(
+                    name=field['name'],
+                    description=field.get('doc', None),
+                    col_type=KafkaSchemaRegistryExtractor._get_property_type(
+                        field
+                    ),
+                    sort_order=i,
+                )
+            )
+
+        return TableMetadata(
+            database='kafka_schema_registry',
+            cluster=cluster_name,
+            schema=subject_name,
+            name=schema_name,
+            description=schema_description,
+            columns=columns,
+        )
+
+    @staticmethod
+    def _get_property_type(schema: Dict) -> str:
+        """
+        Return the type of the given schema.
+        It also works for nested schema types.
+        """
+        if 'type' not in schema:
+            return 'object'
+
+        if type(schema['type']) is dict:
+            return KafkaSchemaRegistryExtractor._get_property_type(
+                schema['type']
+            )
+
+        # The schema can have multiple types
+        if type(schema['type']) is list:
+            return '|'.join(schema['type'])
+
+        if schema['type'] == 'record':
+            properties = [
+                f"{field['name']}:"
+                f"{KafkaSchemaRegistryExtractor._get_property_type(field)}"
+                for field in schema.get('fields', {})
+            ]
+            if len(properties) > 0:
+                if 'name' in schema:
+                    return schema['name'] + \
+                        ':struct<' + ','.join(properties) + '>'
+                return 'struct<' + ','.join(properties) + '>'
+            return 'struct'
+        elif schema['type'] == 'array':
+            items = KafkaSchemaRegistryExtractor._get_property_type(
+                schema.get("items", {})
+            )
+            return 'array<' + items + '>'
+        else:
+            return schema['type']
diff --git a/databuilder/databuilder/extractor/kafka_source_extractor.py b/databuilder/databuilder/extractor/kafka_source_extractor.py
new file mode 100644
index 0000000000..99600cd8d2
--- /dev/null
+++ b/databuilder/databuilder/extractor/kafka_source_extractor.py
@@ -0,0 +1,172 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+import importlib
+import logging
+from datetime import datetime, timedelta
+from typing import Any
+
+from confluent_kafka import (
+    Consumer, KafkaError, KafkaException,
+)
+from pyhocon import ConfigTree
+
+from databuilder import Scoped
+from databuilder.callback.call_back import Callback
+from databuilder.extractor.base_extractor import Extractor
+from databuilder.transformer.base_transformer import Transformer
+
+LOGGER = logging.getLogger(__name__)
+
+
+class KafkaSourceExtractor(Extractor, Callback):
+    """
+    Kafka source extractor. The extractor itself is a single (single-threaded) consumer
+    that can consume all partitions of a given topic, or a subset of them.
+
+    It uses a "micro-batch" approach to ingest data from a given Kafka topic and
+    persist it into the downstream sink.
+    Once the publisher commits successfully, the extractor's callback is triggered to commit the
+    consumer offset.
+    """
+    # The dict of Kafka consumer config
+    CONSUMER_CONFIG = 'consumer_config'
+    # The consumer group id. Ideally each Kafka extractor should only be associated with one consumer group.
+    CONSUMER_GROUP_ID = 'group.id'
+    # We don't deserde the key of the message.
+    # CONSUMER_VALUE_DESERDE = 'value.deserializer'
+    # Each Kafka extractor should only consume one single topic. This could be extended to more topics if needed.
+    TOPIC_NAME_LIST = 'topic_name_list'
+
+    # Timeout config, in seconds. Reading from the Kafka topic is aborted once the timeout is reached.
+    CONSUMER_TOTAL_TIMEOUT_SEC = 'consumer_total_timeout_sec'
+
+    # The timeout for the consumer polling messages. Defaults to 1 sec
+    CONSUMER_POLL_TIMEOUT_SEC = 'consumer_poll_timeout_sec'
+
+    # Config on whether to raise an exception if the transformation fails
+    TRANSFORMER_THROWN_EXCEPTION = 'transformer_thrown_exception'
+
+    # The value transformer used to deserde the Kafka message
+    RAW_VALUE_TRANSFORMER = 'raw_value_transformer'
+
+    def init(self, conf: ConfigTree) -> None:
+        self.conf = conf
+        self.consumer_config = conf.get_config(KafkaSourceExtractor.CONSUMER_CONFIG).\
+            as_plain_ordered_dict()
+
+        self.topic_names: list = conf.get_list(KafkaSourceExtractor.TOPIC_NAME_LIST)
+
+        if not self.topic_names:
+            raise Exception('Kafka topic needs to be provided by the user.')
+
+        self.consumer_total_timeout = conf.get_int(KafkaSourceExtractor.CONSUMER_TOTAL_TIMEOUT_SEC,
+                                                   default=10)
+
+        self.consumer_poll_timeout = conf.get_int(KafkaSourceExtractor.CONSUMER_POLL_TIMEOUT_SEC,
+                                                  default=1)
+
+        self.transformer_thrown_exception = conf.get_bool(KafkaSourceExtractor.TRANSFORMER_THROWN_EXCEPTION,
+                                                          default=False)
+
+        # Transform the protobuf message with a transformer
+        val_transformer = conf.get(KafkaSourceExtractor.RAW_VALUE_TRANSFORMER)
+        if val_transformer is None:
+            raise Exception('A message transformer should be provided.')
+        else:
+            try:
+                module_name, class_name = val_transformer.rsplit(".", 1)
+                mod = importlib.import_module(module_name)
+                self.transformer = getattr(mod, class_name)()
+            except Exception:
+                raise RuntimeError('The Kafka message value deserde class could not be instantiated!')
+
+            if not isinstance(self.transformer, Transformer):
+                raise Exception('The transformer needs to be a subclass of the base transformer')
+            self.transformer.init(Scoped.get_scoped_conf(conf, self.transformer.get_scope()))
+
+        # Consumer init
+        try:
+            # Disable enable.auto.commit
+            self.consumer_config['enable.auto.commit'] = False
+
+            self.consumer = Consumer(self.consumer_config)
+            # TODO: support consuming only a subset
of partitions. + self.consumer.subscribe(self.topic_names) + except Exception: + raise RuntimeError('Consumer could not start correctly!') + + def extract(self) -> Any: + """ + :return: Provides a record or None if no more to extract + """ + records = self.consume() + for record in records: + try: + transform_record = self.transformer.transform(record=record) + yield transform_record + except Exception as e: + # Has issues tranform / deserde the record. drop the record in default + LOGGER.exception(e) + if self.transformer_thrown_exception: + # if config enabled, it will throw exception. + # Users need to figure out how to rewind the consumer offset + raise Exception('Encounter exception when transform the record') + + def on_success(self) -> None: + """ + Commit the offset + once: + 1. get the success callback from publisher in + https://github.com/amundsen-io/amundsendatabuilder/blob/ + master/databuilder/publisher/base_publisher.py#L50 + 2. close the consumer. + + :return: + """ + # set enable.auto.commit to False to avoid auto commit offset + if self.consumer: + self.consumer.commit(asynchronous=False) + self.consumer.close() + + def on_failure(self) -> None: + if self.consumer: + self.consumer.close() + + def consume(self) -> Any: + """ + Consume messages from a give list of topic + + :return: + """ + records = [] + start = datetime.now() + try: + while True: + msg = self.consumer.poll(timeout=self.consumer_poll_timeout) + end = datetime.now() + + # The consumer exceeds consume timeout + if (end - start) > timedelta(seconds=self.consumer_total_timeout): + # Exceed the consume timeout + break + + if msg is None: + continue + + if msg.error(): + # Hit the EOF of partition + if msg.error().code() == KafkaError._PARTITION_EOF: + continue + else: + raise KafkaException(msg.error()) + else: + records.append(msg.value()) + + except Exception as e: + LOGGER.exception(e) + finally: + return records + + def get_scope(self) -> str: + return 'extractor.kafka_source' diff --git a/databuilder/databuilder/extractor/mssql_metadata_extractor.py b/databuilder/databuilder/extractor/mssql_metadata_extractor.py new file mode 100644 index 0000000000..ce2d66e6a9 --- /dev/null +++ b/databuilder/databuilder/extractor/mssql_metadata_extractor.py @@ -0,0 +1,176 @@ +# Copyright Contributors to the Amundsen project. 
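# Illustrative configuration sketch for the extractor defined below (the schema list is a
# hypothetical placeholder). The metastore connection string itself is supplied through the
# nested SQLAlchemy extractor configuration, which is not shown here:
#
#     from pyhocon import ConfigFactory
#
#     mssql_conf = ConfigFactory.from_dict({
#         'where_clause_suffix': "('dbo', 'sales')",   # becomes: and tbl.table_schema in ('dbo', 'sales')
#         'use_catalog_as_cluster_name': True,         # cluster column is taken from DB_NAME()
#     })
#
# An empty where_clause_suffix (the default) disables the schema filter entirely.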
+# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections import namedtuple +from itertools import groupby +from typing import ( + Any, Dict, Iterator, Union, +) + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor import sql_alchemy_extractor +from databuilder.extractor.base_extractor import Extractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +TableKey = namedtuple('TableKey', ['schema_name', 'table_name']) + +LOGGER = logging.getLogger(__name__) + + +class MSSQLMetadataExtractor(Extractor): + """ + Extracts Microsoft SQL Server table and column metadata from underlying + meta store database using SQLAlchemyExtractor + """ + + # SELECT statement from MS SQL to extract table and column metadata + SQL_STATEMENT = """ + SELECT DISTINCT + {cluster_source} AS cluster, + TBL.TABLE_SCHEMA AS [schema_name], + TBL.TABLE_NAME AS [name], + CAST(PROP.VALUE AS NVARCHAR(MAX)) AS [description], + COL.COLUMN_NAME AS [col_name], + COL.DATA_TYPE AS [col_type], + CAST(PROP_COL.VALUE AS NVARCHAR(MAX)) AS [col_description], + COL.ORDINAL_POSITION AS col_sort_order + FROM INFORMATION_SCHEMA.TABLES TBL + INNER JOIN INFORMATION_SCHEMA.COLUMNS COL + ON (COL.TABLE_NAME = TBL.TABLE_NAME + AND COL.TABLE_SCHEMA = TBL.TABLE_SCHEMA ) + LEFT JOIN SYS.EXTENDED_PROPERTIES PROP + ON (PROP.MAJOR_ID = OBJECT_ID(TBL.TABLE_SCHEMA + '.' + TBL.TABLE_NAME) + AND PROP.MINOR_ID = 0 + AND PROP.NAME = 'MS_Description') + LEFT JOIN SYS.EXTENDED_PROPERTIES PROP_COL + ON (PROP_COL.MAJOR_ID = OBJECT_ID(TBL.TABLE_SCHEMA + '.' + TBL.TABLE_NAME) + AND PROP_COL.MINOR_ID = COL.ORDINAL_POSITION + AND PROP_COL.NAME = 'MS_Description') + WHERE TBL.TABLE_TYPE = 'base table' {where_clause_suffix} + ORDER BY + CLUSTER, + SCHEMA_NAME, + NAME, + COL_SORT_ORDER + ; + """ + + # CONFIG KEYS + WHERE_CLAUSE_SUFFIX_KEY = 'where_clause_suffix' + CLUSTER_KEY = 'cluster_key' + USE_CATALOG_AS_CLUSTER_NAME = 'use_catalog_as_cluster_name' + DATABASE_KEY = 'database_key' + + # Default values + DEFAULT_CLUSTER_NAME = 'DB_NAME()' + + DEFAULT_CONFIG = ConfigFactory.from_dict({ + WHERE_CLAUSE_SUFFIX_KEY: '', + CLUSTER_KEY: DEFAULT_CLUSTER_NAME, + USE_CATALOG_AS_CLUSTER_NAME: True} + ) + + DEFAULT_WHERE_CLAUSE_VALUE = 'and tbl.table_schema in {schemas}' + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(MSSQLMetadataExtractor.DEFAULT_CONFIG) + + self._cluster = conf.get_string(MSSQLMetadataExtractor.CLUSTER_KEY) + + if conf.get_bool(MSSQLMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME): + cluster_source = "DB_NAME()" + else: + cluster_source = f"'{self._cluster}'" + + self._database = conf.get_string( + MSSQLMetadataExtractor.DATABASE_KEY, + default='mssql') + + config_where_clause = conf.get_string( + MSSQLMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY) + + LOGGER.info("Crawling for Schemas %s", config_where_clause) + + if config_where_clause: + where_clause_suffix = MSSQLMetadataExtractor \ + .DEFAULT_WHERE_CLAUSE_VALUE \ + .format(schemas=config_where_clause) + else: + where_clause_suffix = '' + + self.sql_stmt = MSSQLMetadataExtractor.SQL_STATEMENT.format( + where_clause_suffix=where_clause_suffix, + cluster_source=cluster_source + ) + + LOGGER.info('SQL for MS SQL Metadata: %s', self.sql_stmt) + + self._alchemy_extractor = sql_alchemy_extractor.from_surrounding_config(conf, self.sql_stmt) + self._extract_iter: Union[None, Iterator] = None + + def close(self) -> None: + if getattr(self, '_alchemy_extractor', None) is not None: + self._alchemy_extractor.close() + + def 
extract(self) -> Union[TableMetadata, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.mssql_metadata' + + def _get_extract_iter(self) -> Iterator[TableMetadata]: + """ + Using itertools.groupby and raw level iterator, + it groups to table and yields TableMetadata + :return: + """ + for key, group in groupby(self._get_raw_extract_iter(), self._get_table_key): + columns = [] + + for row in group: + last_row = row + columns.append( + ColumnMetadata( + row['col_name'], + row['col_description'], + row['col_type'], + row['col_sort_order'])) + + yield TableMetadata( + self._database, + last_row['cluster'], + last_row['schema_name'], + last_row['name'], + last_row['description'], + columns, + tags=last_row['schema_name']) + + def _get_raw_extract_iter(self) -> Iterator[Dict[str, Any]]: + """ + Provides iterator of result row from SQLAlchemy extractor + :return: + """ + row = self._alchemy_extractor.extract() + while row: + yield row + row = self._alchemy_extractor.extract() + + def _get_table_key(self, row: Dict[str, Any]) -> Union[TableKey, None]: + """ + Table key consists of schema and table name + :param row: + :return: + """ + if row: + return TableKey( + schema_name=row['schema_name'], + table_name=row['name']) + + return None diff --git a/databuilder/databuilder/extractor/mysql_metadata_extractor.py b/databuilder/databuilder/extractor/mysql_metadata_extractor.py new file mode 100644 index 0000000000..ab58a7f31d --- /dev/null +++ b/databuilder/databuilder/extractor/mysql_metadata_extractor.py @@ -0,0 +1,139 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections import namedtuple +from itertools import groupby +from typing import ( + Any, Dict, Iterator, Union, +) + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +TableKey = namedtuple('TableKey', ['schema', 'table_name']) + +LOGGER = logging.getLogger(__name__) + + +class MysqlMetadataExtractor(Extractor): + """ + Extracts mysql table and column metadata from underlying meta store database using SQLAlchemyExtractor + """ + # SELECT statement from mysql information_schema to extract table and column metadata + SQL_STATEMENT = """ + SELECT + lower(c.column_name) AS col_name, + c.column_comment AS col_description, + lower(c.data_type) AS col_type, + lower(c.ordinal_position) AS col_sort_order, + {cluster_source} AS cluster, + lower(c.table_schema) AS "schema", + lower(c.table_name) AS name, + t.table_comment AS description, + case when lower(t.table_type) = "view" then "true" else "false" end AS is_view + FROM + INFORMATION_SCHEMA.COLUMNS AS c + LEFT JOIN + INFORMATION_SCHEMA.TABLES t + ON c.TABLE_NAME = t.TABLE_NAME + AND c.TABLE_SCHEMA = t.TABLE_SCHEMA + {where_clause_suffix} + ORDER by cluster, "schema", name, col_sort_order ; + """ + + # CONFIG KEYS + WHERE_CLAUSE_SUFFIX_KEY = 'where_clause_suffix' + CLUSTER_KEY = 'cluster_key' + USE_CATALOG_AS_CLUSTER_NAME = 'use_catalog_as_cluster_name' + DATABASE_KEY = 'database_key' + + # Default values + DEFAULT_CLUSTER_NAME = 'master' + + DEFAULT_CONFIG = ConfigFactory.from_dict( + {WHERE_CLAUSE_SUFFIX_KEY: ' ', 
CLUSTER_KEY: DEFAULT_CLUSTER_NAME, USE_CATALOG_AS_CLUSTER_NAME: True} + ) + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(MysqlMetadataExtractor.DEFAULT_CONFIG) + self._cluster = conf.get_string(MysqlMetadataExtractor.CLUSTER_KEY) + + if conf.get_bool(MysqlMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME): + cluster_source = "c.table_catalog" + else: + cluster_source = f"'{self._cluster}'" + + self._database = conf.get_string(MysqlMetadataExtractor.DATABASE_KEY, default='mysql') + + self.sql_stmt = MysqlMetadataExtractor.SQL_STATEMENT.format( + where_clause_suffix=conf.get_string(MysqlMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY), + cluster_source=cluster_source + ) + + self._alchemy_extractor = SQLAlchemyExtractor() + sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope()) \ + .with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt})) + + self.sql_stmt = sql_alch_conf.get_string(SQLAlchemyExtractor.EXTRACT_SQL) + + LOGGER.info('SQL for mysql metadata: %s', self.sql_stmt) + + self._alchemy_extractor.init(sql_alch_conf) + self._extract_iter: Union[None, Iterator] = None + + def extract(self) -> Union[TableMetadata, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.mysql_metadata' + + def _get_extract_iter(self) -> Iterator[TableMetadata]: + """ + Using itertools.groupby and raw level iterator, it groups to table and yields TableMetadata + :return: + """ + for key, group in groupby(self._get_raw_extract_iter(), self._get_table_key): + columns = [] + + for row in group: + last_row = row + columns.append(ColumnMetadata(row['col_name'], row['col_description'], + row['col_type'], row['col_sort_order'])) + + yield TableMetadata(self._database, last_row['cluster'], + last_row['schema'], + last_row['name'], + last_row['description'], + columns, + is_view=last_row['is_view']) + + def _get_raw_extract_iter(self) -> Iterator[Dict[str, Any]]: + """ + Provides iterator of result row from SQLAlchemy extractor + :return: + """ + row = self._alchemy_extractor.extract() + while row: + yield row + row = self._alchemy_extractor.extract() + + def _get_table_key(self, row: Dict[str, Any]) -> Union[TableKey, None]: + """ + Table key consists of schema and table name + :param row: + :return: + """ + if row: + return TableKey(schema=row['schema'], table_name=row['name']) + + return None diff --git a/databuilder/databuilder/extractor/mysql_search_data_extractor.py b/databuilder/databuilder/extractor/mysql_search_data_extractor.py new file mode 100644 index 0000000000..d98138a632 --- /dev/null +++ b/databuilder/databuilder/extractor/mysql_search_data_extractor.py @@ -0,0 +1,539 @@ +# Copyright Contributors to the Amundsen project. 
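# Illustrative usage sketch for the extractor defined below (the connection string is a
# hypothetical placeholder):
#
#     from pyhocon import ConfigFactory
#     from databuilder.extractor.mysql_search_data_extractor import MySQLSearchDataExtractor
#
#     extractor = MySQLSearchDataExtractor()
#     extractor.init(ConfigFactory.from_dict({
#         'conn_string': 'mysql+pymysql://user:password@localhost:3306/amundsen',  # placeholder
#         'entity_type': 'table',   # or 'user' / 'dashboard', selecting a search function below
#     }))
#     document = extractor.extract()  # one search document (a plain dict unless model_class is set)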
+# SPDX-License-Identifier: Apache-2.0 + +import importlib +import logging +from typing import ( + Any, Callable, Dict, Iterator, List, Optional, +) + +from amundsen_rds.models.badge import Badge +from amundsen_rds.models.cluster import Cluster +from amundsen_rds.models.column import ColumnDescription, TableColumn +from amundsen_rds.models.dashboard import ( + Dashboard, DashboardChart, DashboardCluster, DashboardDescription, DashboardExecution, DashboardFollower, + DashboardGroup, DashboardGroupDescription, DashboardOwner, DashboardQuery, DashboardUsage, +) +from amundsen_rds.models.database import Database +from amundsen_rds.models.schema import Schema, SchemaDescription +from amundsen_rds.models.table import ( + Table, TableDescription, TableFollower, TableOwner, TableProgrammaticDescription, TableTimestamp, TableUsage, +) +from amundsen_rds.models.tag import Tag +from amundsen_rds.models.user import User +from pyhocon import ConfigTree +from sqlalchemy import create_engine, func +from sqlalchemy.orm import ( + Session, load_only, sessionmaker, subqueryload, +) + +from databuilder.extractor.base_extractor import Extractor + +LOGGER = logging.getLogger(__name__) + + +def _table_search_query(session: Session, table_filter: List, offset: int, limit: int) -> List: + """ + Table query + :param session: + :param table_filter: + :param offset: + :param limit: + :return: + """ + # table + query = session.query(Table).filter(*table_filter).options( + load_only(Table.rk, Table.name, Table.schema_rk) + ) + + # description + query = query.options( + subqueryload(Table.description).options( + load_only(TableDescription.description) + ) + ).options( + subqueryload(Table.programmatic_descriptions).options( + load_only(TableProgrammaticDescription.description) + ) + ) + + # schema, cluster, database + query = query.options( + subqueryload(Table.schema).options( + load_only(Schema.name, Schema.cluster_rk), + subqueryload(Schema.description).options( + load_only(SchemaDescription.description) + ), + subqueryload(Schema.cluster).options( + load_only(Cluster.name, Cluster.database_rk), + subqueryload(Cluster.database).options( + load_only(Database.name) + ) + ) + ) + ) + + # column + query = query.options( + subqueryload(Table.columns).options( + load_only(TableColumn.rk, TableColumn.name), + subqueryload(TableColumn.description).options( + load_only(ColumnDescription.description) + ) + ) + ) + + # tag, badge + query = query.options( + subqueryload(Table.tags).options( + load_only(Tag.rk, Tag.tag_type) + ) + ).options( + subqueryload(Table.badges).options( + load_only(Badge.rk) + ) + ) + + # usage + query = query.options( + subqueryload(Table.usage).options( + load_only(TableUsage.read_count) + ) + ) + + # timestamp + query = query.options( + subqueryload(Table.timestamp).options( + load_only(TableTimestamp.last_updated_timestamp) + ) + ) + + query = query.order_by(Table.rk).offset(offset).limit(limit) + + return query.all() + + +def _table_search(session: Session, published_tag: str, limit: int) -> List[Dict]: + """ + Query table metadata. 
+ :param session: + :param published_tag: + :param limit: + :return: + """ + LOGGER.info('Querying table metadata.') + + table_filter = [] + if published_tag: + table_filter.append(Table.published_tag == published_tag) + + table_results = [] + + offset = 0 + tables = _table_search_query(session, table_filter, offset, limit) + while tables: + for table in tables: + schema = table.schema + schema_description = schema.description.description if schema.description else None + cluster = schema.cluster + database = cluster.database + description = table.description.description if table.description else '' + programmatic_descriptions = [description.description + for description in table.programmatic_descriptions] + + columns = table.columns + column_names = [column.name for column in columns] + column_descriptions = [column.description.description if column.description else '' + for column in columns] + + total_usage = sum(usage.read_count for usage in table.usage) + unique_usage = len(table.usage) + + tags = [tag.rk for tag in table.tags if tag.tag_type == 'default'] + badges = [badge.rk for badge in table.badges] + last_updated_timestamp = table.timestamp.last_updated_timestamp if table.timestamp else None + + table_result = dict(database=database.name, + cluster=cluster.name, + schema=schema.name, + name=table.name, + key=table.rk, + description=description, + last_updated_timestamp=last_updated_timestamp, + column_names=column_names, + column_descriptions=column_descriptions, + total_usage=total_usage, + unique_usage=unique_usage, + tags=tags, + badges=badges, + schema_description=schema_description, + programmatic_descriptions=programmatic_descriptions) + table_results.append(table_result) + + offset += limit + tables = _table_search_query(session, table_filter, offset, limit) + + return table_results + + +def _dashboard_search_query(session: Session, dashboard_filter: List, offset: int, limit: int) -> List: + """ + Dashboard query + :param session: + :param dashboard_filter: + :param offset: + :param limit: + :return: + """ + # dashboard + query = session.query(Dashboard).filter(*dashboard_filter).options( + load_only(Dashboard.rk, + Dashboard.name, + Dashboard.dashboard_url, + Dashboard.dashboard_group_rk) + ) + + # group, cluster + query = query.options( + subqueryload(Dashboard.group).options( + load_only(DashboardGroup.rk, + DashboardGroup.name, + DashboardGroup.dashboard_group_url, + DashboardGroup.cluster_rk), + subqueryload(DashboardGroup.description).options( + load_only(DashboardGroupDescription.description) + ), + subqueryload(DashboardGroup.cluster).options( + load_only(DashboardCluster.name) + ) + ) + ) + + # description + query = query.options( + subqueryload(Dashboard.description).options( + load_only(DashboardDescription.description) + ) + ) + + # execution + query = query.options( + subqueryload(Dashboard.execution).options( + load_only(DashboardExecution.rk, DashboardExecution.timestamp) + ) + ) + + # usage + query = query.options( + subqueryload(Dashboard.usage).options( + load_only(DashboardUsage.read_count) + ) + ) + + # query, chart + query = query.options( + subqueryload(Dashboard.queries).options( + load_only(DashboardQuery.name), + subqueryload(DashboardQuery.charts).options( + load_only(DashboardChart.name) + ) + ) + ) + + # tag, badge + query = query.options( + subqueryload(Dashboard.tags).options( + load_only(Tag.rk, Tag.tag_type) + ) + ).options( + subqueryload(Dashboard.badges).options( + load_only(Badge.rk) + ) + ) + + query = 
query.order_by(Dashboard.rk).offset(offset).limit(limit) + + return query.all() + + +def _dashboard_search(session: Session, published_tag: str, limit: int) -> List[Dict]: + """ + Query dashboard metadata. + :param session: + :param published_tag: + :param limit: + :return: + """ + LOGGER.info('Querying dashboard metadata.') + + dashboard_filter = [] + if published_tag: + dashboard_filter.append(Dashboard.published_tag == published_tag) + + dashboard_results = [] + + offset = 0 + dashboards = _dashboard_search_query(session, dashboard_filter, offset, limit) + while dashboards: + for dashboard in dashboards: + group = dashboard.group + description = dashboard.description.description if dashboard.description else None + group_description = group.description.description if group.description else None + cluster = group.cluster + product = dashboard.rk.split('_')[0] + last_exec = next((execution for execution in dashboard.execution + if execution.rk.endswith('_last_successful_execution')), None) + last_successful_run_timestamp = last_exec.timestamp if last_exec else None + total_usage = sum(usage.read_count for usage in dashboard.usage) + + queries = dashboard.queries + query_names = [query.name for query in queries] + chart_names = [chart.name for query in queries for chart in query.charts] + + tags = [tag.rk for tag in dashboard.tags if tag.tag_type == 'default'] + badges = [badge.rk for badge in dashboard.badges] + + dashboard_result = dict(group_name=group.name, + name=dashboard.name, + description=description, + total_usage=total_usage, + product=product, + cluster=cluster.name, + group_description=group_description, + query_names=query_names, + chart_names=chart_names, + group_url=group.dashboard_group_url, + url=dashboard.dashboard_url, + uri=dashboard.rk, + last_successful_run_timestamp=last_successful_run_timestamp, + tags=tags, + badges=badges) + + dashboard_results.append(dashboard_result) + + offset += limit + dashboards = _dashboard_search_query(session, dashboard_filter, offset, limit) + + return dashboard_results + + +def _user_search_query(session: Session, user_filter: List, offset: int, limit: int) -> List: + """ + User query + :param session: + :param user_filter: + :param offset: + :param limit: + :return: + """ + # read + table_usage_subquery = session \ + .query(User.rk, func.sum(TableUsage.read_count).label('table_read_count')) \ + .outerjoin(TableUsage) \ + .filter(*user_filter) \ + .group_by(User.rk).order_by(User.rk).offset(offset).limit(limit).subquery() + + dashboard_usage_subquery = session \ + .query(User.rk, func.sum(DashboardUsage.read_count).label('dashboard_read_count')) \ + .outerjoin(DashboardUsage) \ + .filter(*user_filter) \ + .group_by(User.rk).order_by(User.rk).offset(offset).limit(limit).subquery() + + # own + table_owner_subquery = session \ + .query(User.rk, func.count(TableOwner.table_rk).label('table_own_count')) \ + .outerjoin(TableOwner) \ + .filter(*user_filter) \ + .group_by(User.rk).order_by(User.rk).offset(offset).limit(limit).subquery() + + dashboard_owner_subquery = session \ + .query(User.rk, func.count(DashboardOwner.dashboard_rk).label('dashboard_own_count')) \ + .outerjoin(DashboardOwner) \ + .filter(*user_filter) \ + .group_by(User.rk).order_by(User.rk).offset(offset).limit(limit).subquery() + + # follow + table_follower_subquery = session \ + .query(User.rk, func.count(TableFollower.table_rk).label('table_follow_count')) \ + .outerjoin(TableFollower) \ + .filter(*user_filter) \ + 
.group_by(User.rk).order_by(User.rk).offset(offset).limit(limit).subquery() + + dashboard_follower_subquery = session \ + .query(User.rk, func.count(DashboardFollower.dashboard_rk).label('dashboard_follow_count')) \ + .outerjoin(DashboardFollower) \ + .filter(*user_filter) \ + .group_by(User.rk).order_by(User.rk).offset(offset).limit(limit).subquery() + + # user + query = session \ + .query(User, + table_usage_subquery.c.table_read_count, + dashboard_usage_subquery.c.dashboard_read_count, + table_owner_subquery.c.table_own_count, + dashboard_owner_subquery.c.dashboard_own_count, + table_follower_subquery.c.table_follow_count, + dashboard_follower_subquery.c.dashboard_follow_count) \ + .join(table_usage_subquery, table_usage_subquery.c.rk == User.rk) \ + .join(dashboard_usage_subquery, dashboard_usage_subquery.c.rk == User.rk) \ + .join(table_owner_subquery, table_owner_subquery.c.rk == User.rk) \ + .join(dashboard_owner_subquery, dashboard_owner_subquery.c.rk == User.rk) \ + .join(table_follower_subquery, table_follower_subquery.c.rk == User.rk) \ + .join(dashboard_follower_subquery, dashboard_follower_subquery.c.rk == User.rk) + + # manager + query = query.options( + subqueryload(User.manager).options( + load_only(User.email) + ) + ) + + return query.all() + + +def _user_search(session: Session, published_tag: str, limit: int) -> List[Dict]: + """ + Query user metadata. + :param session: + :param published_tag: + :param limit: + :return: + """ + LOGGER.info('Querying user metadata.') + + user_filter = [User.full_name.isnot(None)] + if published_tag: + user_filter.append(User.published_tag == published_tag) + + user_results = [] + + offset = 0 + query_results = _user_search_query(session, user_filter, offset, limit) + while query_results: + for query_result in query_results: + user = query_result.User + table_read_count = int(query_result.table_read_count) if query_result.table_read_count else 0 + dashboard_read_count = int(query_result.dashboard_read_count) if query_result.dashboard_read_count else 0 + total_read_count = table_read_count + dashboard_read_count + + table_own_count = query_result.table_own_count + dashboard_own_count = query_result.dashboard_own_count + total_own_count = table_own_count + dashboard_own_count + + table_follow_count = query_result.table_follow_count + dashboard_follow_count = query_result.dashboard_follow_count + total_follow_count = table_follow_count + dashboard_follow_count + + manager_email = user.manager.email if user.manager else '' + user_result = dict(email=user.email, + first_name=user.first_name, + last_name=user.last_name, + full_name=user.full_name, + github_username=user.github_username, + team_name=user.team_name, + employee_type=user.employee_type, + manager_email=manager_email, + slack_id=user.slack_id, + role_name=user.role_name, + is_active=user.is_active, + total_read=total_read_count, + total_own=total_own_count, + total_follow=total_follow_count) + + user_results.append(user_result) + + offset += limit + query_results = _user_search_query(session, user_filter, offset, limit) + + return user_results + + +class MySQLSearchDataExtractor(Extractor): + """ + Extractor to fetch data required to support search from MySQL. 
+ """ + ENTITY_TYPE = 'entity_type' + MODEL_CLASS = 'model_class' + JOB_PUBLISH_TAG = 'job_publish_tag' + SEARCH_FUNCTION = 'search_function' + + CONN_STRING = 'conn_string' + ENGINE_ECHO = 'engine_echo' + CONNECT_ARGS = 'connect_args' + QUERY_LIMIT = 'query_limit' + + _DEFAULT_QUERY_LIMIT = 500 + _DEFAULT_SEARCH_BY_ENTITY: Dict[str, Callable] = { + 'table': _table_search, + 'user': _user_search, + 'dashboard': _dashboard_search + } + + def init(self, conf: ConfigTree) -> None: + self.conf = conf + self.entity = conf.get_string(MySQLSearchDataExtractor.ENTITY_TYPE, default='table').lower() + if MySQLSearchDataExtractor.SEARCH_FUNCTION in conf: + self.search_function = conf.get(MySQLSearchDataExtractor.SEARCH_FUNCTION) + else: + self.search_function = MySQLSearchDataExtractor._DEFAULT_SEARCH_BY_ENTITY[self.entity] + self.published_tag = conf.get_string(MySQLSearchDataExtractor.JOB_PUBLISH_TAG, '') + self.query_limit = conf.get_int(MySQLSearchDataExtractor.QUERY_LIMIT, self._DEFAULT_QUERY_LIMIT) + + connect_args = {k: v for k, v in self.conf.get_config(MySQLSearchDataExtractor.CONNECT_ARGS, + default=ConfigTree()).items()} + self._engine = create_engine(conf.get_string(MySQLSearchDataExtractor.CONN_STRING), + echo=conf.get_bool(MySQLSearchDataExtractor.ENGINE_ECHO, False), + connect_args=connect_args) + self._session_factory = sessionmaker(bind=self._engine) + + model_class = conf.get(MySQLSearchDataExtractor.MODEL_CLASS, None) + if model_class: + module_name, class_name = model_class.rsplit(".", 1) + mod = importlib.import_module(module_name) + self.model_class = getattr(mod, class_name) + + self._extract_iter: Optional[Iterator] = None + + def close(self) -> None: + """ + Close connection to mysql. + """ + try: + self._engine.dispose() + except Exception as e: + LOGGER.error(f'Exception encountered while closing engine: {e}') + + def extract(self) -> Optional[Any]: + """ + Return an object or a raw query result. + """ + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + + try: + return next(self._extract_iter) + except StopIteration: + return None + + def _get_extract_iter(self) -> Iterator[Any]: + if not hasattr(self, 'results'): + session = self._session_factory() + try: + self.results = self.search_function(session=session, + published_tag=self.published_tag, + limit=self.query_limit) + except Exception as e: + LOGGER.exception('Exception encountered while executing the search function.') + raise e + finally: + session.close() + + for result in self.results: + if hasattr(self, 'model_class'): + obj = self.model_class(**result) + yield obj + else: + yield result + + def get_scope(self) -> str: + return 'extractor.mysql_search_data' diff --git a/databuilder/databuilder/extractor/neo4j_extractor.py b/databuilder/databuilder/extractor/neo4j_extractor.py new file mode 100644 index 0000000000..8d0cfb186e --- /dev/null +++ b/databuilder/databuilder/extractor/neo4j_extractor.py @@ -0,0 +1,136 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import importlib +import logging +from typing import ( + Any, Iterator, Union, +) + +import neo4j +from neo4j import GraphDatabase +from neo4j.api import ( + SECURITY_TYPE_SECURE, SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, parse_neo4j_uri, +) +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor.base_extractor import Extractor + +LOGGER = logging.getLogger(__name__) + + +class Neo4jExtractor(Extractor): + """ + Extractor to fetch records from Neo4j Graph database + """ + CYPHER_QUERY_CONFIG_KEY = 'cypher_query' + GRAPH_URL_CONFIG_KEY = 'graph_url' + MODEL_CLASS_CONFIG_KEY = 'model_class' + NEO4J_AUTH_USER = 'neo4j_auth_user' + NEO4J_AUTH_PW = 'neo4j_auth_pw' + # in Neo4j (v4.0+), we can create and use more than one active database at the same time + NEO4J_DATABASE_NAME = 'neo4j_database' + NEO4J_MAX_CONN_LIFE_TIME_SEC = 'neo4j_max_conn_life_time_sec' + NEO4J_ENCRYPTED = 'neo4j_encrypted' + """NEO4J_ENCRYPTED is a boolean indicating whether to use SSL/TLS when connecting.""" + NEO4J_VALIDATE_SSL = 'neo4j_validate_ssl' + """NEO4J_VALIDATE_SSL is a boolean indicating whether to validate the server's SSL/TLS cert against system CAs.""" + + DEFAULT_CONFIG = ConfigFactory.from_dict({ + NEO4J_MAX_CONN_LIFE_TIME_SEC: 50, + NEO4J_DATABASE_NAME: neo4j.DEFAULT_DATABASE, + }) + + def init(self, conf: ConfigTree) -> None: + """ + Establish connections and import data model class if provided + :param conf: + """ + self.conf = conf.with_fallback(Neo4jExtractor.DEFAULT_CONFIG) + self.graph_url = self.conf.get_string(Neo4jExtractor.GRAPH_URL_CONFIG_KEY) + self.cypher_query = self.conf.get_string(Neo4jExtractor.CYPHER_QUERY_CONFIG_KEY) + self.db_name = self.conf.get_string(Neo4jExtractor.NEO4J_DATABASE_NAME) + + uri = self.conf.get_string(Neo4jExtractor.GRAPH_URL_CONFIG_KEY) + driver_args = { + 'uri': uri, + 'max_connection_lifetime': self.conf.get_int(Neo4jExtractor.NEO4J_MAX_CONN_LIFE_TIME_SEC), + 'auth': (self.conf.get_string(Neo4jExtractor.NEO4J_AUTH_USER), + self.conf.get_string(Neo4jExtractor.NEO4J_AUTH_PW)), + } + + # if URI scheme not secure set `trust`` and `encrypted` to default values + # https://neo4j.com/docs/api/python-driver/current/api.html#uri + _, security_type, _ = parse_neo4j_uri(uri=uri) + if security_type not in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE]: + default_security_conf = {'trust': neo4j.TRUST_ALL_CERTIFICATES, 'encrypted': True} + driver_args.update(default_security_conf) + + # if NEO4J_VALIDATE_SSL or NEO4J_ENCRYPTED are set in config pass them to the driver + validate_ssl_conf = self.conf.get(Neo4jExtractor.NEO4J_VALIDATE_SSL, None) + encrypted_conf = self.conf.get(Neo4jExtractor.NEO4J_ENCRYPTED, None) + if validate_ssl_conf is not None: + driver_args['trust'] = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if validate_ssl_conf \ + else neo4j.TRUST_ALL_CERTIFICATES + if encrypted_conf is not None: + driver_args['encrypted'] = encrypted_conf + + self.driver = GraphDatabase.driver(**driver_args) + + self._extract_iter: Union[None, Iterator] = None + + model_class = self.conf.get(Neo4jExtractor.MODEL_CLASS_CONFIG_KEY, None) + if model_class: + module_name, class_name = model_class.rsplit(".", 1) + mod = importlib.import_module(module_name) + self.model_class = getattr(mod, class_name) + + def close(self) -> None: + """ + close connection to neo4j cluster + """ + try: + self.driver.close() + except Exception as e: + LOGGER.error("Exception encountered while closing the graph driver", e) + + def 
_execute_query(self, tx: Any) -> Any: + """ + Create an iterator to execute sql. + """ + LOGGER.info('Executing query %s', self.cypher_query) + result = tx.run(self.cypher_query) + return [record for record in result] + + def _get_extract_iter(self) -> Iterator[Any]: + """ + Execute {cypher_query} and yield result one at a time + """ + with self.driver.session( + database=self.db_name + ) as session: + if not hasattr(self, 'results'): + self.results = session.read_transaction(self._execute_query) + + for result in self.results: + if hasattr(self, 'model_class'): + obj = self.model_class(**result) + yield obj + else: + yield result + + def extract(self) -> Any: + """ + Return {result} object as it is or convert to object of + {model_class}, if specified. + """ + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.neo4j' diff --git a/databuilder/databuilder/extractor/neo4j_search_data_extractor.py b/databuilder/databuilder/extractor/neo4j_search_data_extractor.py new file mode 100644 index 0000000000..2e02433203 --- /dev/null +++ b/databuilder/databuilder/extractor/neo4j_search_data_extractor.py @@ -0,0 +1,202 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import textwrap +from typing import Any + +from pyhocon import ConfigTree + +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.neo4j_extractor import Neo4jExtractor +from databuilder.publisher.neo4j_csv_publisher import JOB_PUBLISH_TAG + + +class Neo4jSearchDataExtractor(Extractor): + """ + Extractor to fetch data required to support search from Neo4j graph database + Use Neo4jExtractor extractor class + """ + CYPHER_QUERY_CONFIG_KEY = 'cypher_query' + ENTITY_TYPE = 'entity_type' + + DEFAULT_NEO4J_TABLE_CYPHER_QUERY = textwrap.dedent( + """ + MATCH (db:Database)<-[:CLUSTER_OF]-(cluster:Cluster) + <-[:SCHEMA_OF]-(schema:Schema)<-[:TABLE_OF]-(table:Table) + {publish_tag_filter} + OPTIONAL MATCH (table)-[:DESCRIPTION]->(table_description:Description) + OPTIONAL MATCH (schema)-[:DESCRIPTION]->(schema_description:Description) + OPTIONAL MATCH (table)-[:DESCRIPTION]->(prog_descs:Programmatic_Description) + WITH db, cluster, schema, schema_description, table, table_description, + COLLECT(prog_descs.description) as programmatic_descriptions + OPTIONAL MATCH (table)-[:TAGGED_BY]->(tags:Tag) WHERE tags.tag_type='default' + WITH db, cluster, schema, schema_description, table, table_description, programmatic_descriptions, + COLLECT(DISTINCT tags.key) as tags + OPTIONAL MATCH (table)-[:HAS_BADGE]->(badges:Badge) + WITH db, cluster, schema, schema_description, table, table_description, programmatic_descriptions, tags, + COLLECT(DISTINCT badges.key) as badges + OPTIONAL MATCH (table)-[read:READ_BY]->(user:User) + WITH db, cluster, schema, schema_description, table, table_description, programmatic_descriptions, tags, badges, + SUM(read.read_count) AS total_usage, + COUNT(DISTINCT user.email) as unique_usage + OPTIONAL MATCH (table)-[:COLUMN]->(col:Column) + OPTIONAL MATCH (col)-[:DESCRIPTION]->(col_description:Description) + WITH db, cluster, schema, schema_description, table, table_description, tags, badges, total_usage, unique_usage, + programmatic_descriptions, + COLLECT(col.name) AS column_names, COLLECT(col_description.description) AS column_descriptions + OPTIONAL MATCH 
(table)-[:LAST_UPDATED_AT]->(time_stamp:Timestamp) + RETURN db.name as database, cluster.name AS cluster, schema.name AS schema, + schema_description.description AS schema_description, + table.name AS name, table.key AS key, table_description.description AS description, + time_stamp.last_updated_timestamp AS last_updated_timestamp, + column_names, + column_descriptions, + total_usage, + unique_usage, + tags, + badges, + programmatic_descriptions + ORDER BY table.name; + """ + ) + + DEFAULT_NEO4J_USER_CYPHER_QUERY = textwrap.dedent( + """ + MATCH (user:User) + OPTIONAL MATCH (user)-[read:READ]->(a) + OPTIONAL MATCH (user)-[own:OWNER_OF]->(b) + OPTIONAL MATCH (user)-[follow:FOLLOWED_BY]->(c) + OPTIONAL MATCH (user)-[manage_by:MANAGE_BY]->(manager) + {publish_tag_filter} + with user, a, b, c, read, own, follow, manager + where user.full_name is not null + return user.email as email, user.first_name as first_name, user.last_name as last_name, + user.full_name as full_name, user.github_username as github_username, user.team_name as team_name, + user.employee_type as employee_type, manager.email as manager_email, + user.slack_id as slack_id, user.is_active as is_active, user.role_name as role_name, + REDUCE(sum_r = 0, r in COLLECT(DISTINCT read)| sum_r + r.read_count) AS total_read, + count(distinct b) as total_own, + count(distinct c) AS total_follow + order by user.email + """ + ) + + DEFAULT_NEO4J_DASHBOARD_CYPHER_QUERY = textwrap.dedent( + """ + MATCH (dashboard:Dashboard) + {publish_tag_filter} + MATCH (dashboard)-[:DASHBOARD_OF]->(dbg:Dashboardgroup) + MATCH (dbg)-[:DASHBOARD_GROUP_OF]->(cluster:Cluster) + OPTIONAL MATCH (dashboard)-[:DESCRIPTION]->(db_descr:Description) + OPTIONAL MATCH (dbg)-[:DESCRIPTION]->(dbg_descr:Description) + OPTIONAL MATCH (dashboard)-[:EXECUTED]->(last_exec:Execution) + WHERE split(last_exec.key, '/')[5] = '_last_successful_execution' + OPTIONAL MATCH (dashboard)-[read:READ_BY]->(user:User) + WITH dashboard, dbg, db_descr, dbg_descr, cluster, last_exec, SUM(read.read_count) AS total_usage + OPTIONAL MATCH (dashboard)-[:HAS_QUERY]->(query:Query)-[:HAS_CHART]->(chart:Chart) + WITH dashboard, dbg, db_descr, dbg_descr, cluster, last_exec, COLLECT(DISTINCT query.name) as query_names, + COLLECT(DISTINCT chart.name) as chart_names, + total_usage + OPTIONAL MATCH (dashboard)-[:TAGGED_BY]->(tags:Tag) WHERE tags.tag_type='default' + WITH dashboard, dbg, db_descr, dbg_descr, cluster, last_exec, query_names, chart_names, total_usage, + COLLECT(DISTINCT tags.key) as tags + OPTIONAL MATCH (dashboard)-[:HAS_BADGE]->(badges:Badge) + WITH dashboard, dbg, db_descr, dbg_descr, cluster, last_exec, query_names, chart_names, total_usage, tags, + COLLECT(DISTINCT badges.key) as badges + RETURN dbg.name as group_name, dashboard.name as name, cluster.name as cluster, + coalesce(db_descr.description, '') as description, + coalesce(dbg.description, '') as group_description, dbg.dashboard_group_url as group_url, + dashboard.dashboard_url as url, dashboard.key as uri, + split(dashboard.key, '_')[0] as product, toInteger(last_exec.timestamp) as last_successful_run_timestamp, + query_names, chart_names, total_usage, tags, badges + order by dbg.name + """ + ) + + DEFAULT_NEO4J_FEATURE_CYPHER_QUERY = textwrap.dedent( + """ + MATCH (feature:Feature) + {publish_tag_filter} + OPTIONAL MATCH (fg:Feature_Group)-[:GROUPS]->(feature) + OPTIONAL MATCH (db:Database)-[:AVAILABLE_FEATURE]->(feature) + OPTIONAL MATCH (feature)-[:DESCRIPTION]->(desc:Description) + OPTIONAL MATCH 
(feature)-[:TAGGED_BY]->(tag:Tag) + OPTIONAL MATCH (feature)-[:HAS_BADGE]->(badge:Badge) + OPTIONAL MATCH (feature)-[read:READ_BY]->(user:User) + RETURN + fg.name as feature_group, + feature.name as feature_name, + feature.version as version, + feature.key as key, + SUM(read.read_count) AS total_usage, + feature.status as status, + feature.entity as entity, + desc.description as description, + db.name as availability, + COLLECT(DISTINCT badge.key) as badges, + COLLECT(DISTINCT tag.key) as tags, + toInteger(feature.last_updated_timestamp) as last_updated_timestamp + order by fg.name, feature.name, feature.version + """ + ) + + DEFAULT_QUERY_BY_ENTITY = { + 'table': DEFAULT_NEO4J_TABLE_CYPHER_QUERY, + 'user': DEFAULT_NEO4J_USER_CYPHER_QUERY, + 'dashboard': DEFAULT_NEO4J_DASHBOARD_CYPHER_QUERY, + 'feature': DEFAULT_NEO4J_FEATURE_CYPHER_QUERY, + } + + def init(self, conf: ConfigTree) -> None: + """ + Initialize Neo4jExtractor object from configuration and use that for extraction + """ + self.conf = conf + self.entity = conf.get_string(Neo4jSearchDataExtractor.ENTITY_TYPE, default='table').lower() + # extract cypher query from conf, if specified, else use default query + if Neo4jSearchDataExtractor.CYPHER_QUERY_CONFIG_KEY in conf: + self.cypher_query = conf.get_string(Neo4jSearchDataExtractor.CYPHER_QUERY_CONFIG_KEY) + else: + default_query = Neo4jSearchDataExtractor.DEFAULT_QUERY_BY_ENTITY[self.entity] + self.cypher_query = self._add_publish_tag_filter(conf.get_string(JOB_PUBLISH_TAG, ''), + cypher_query=default_query) + + self.neo4j_extractor = Neo4jExtractor() + # write the cypher query in configs in Neo4jExtractor scope + key = self.neo4j_extractor.get_scope() + '.' + Neo4jExtractor.CYPHER_QUERY_CONFIG_KEY + self.conf.put(key, self.cypher_query) + # initialize neo4j_extractor from configs + self.neo4j_extractor.init(Scoped.get_scoped_conf(self.conf, self.neo4j_extractor.get_scope())) + + def close(self) -> None: + """ + Use close() method specified by neo4j_extractor + to close connection to neo4j cluster + """ + self.neo4j_extractor.close() + + def extract(self) -> Any: + """ + Invoke extract() method defined by neo4j_extractor + """ + return self.neo4j_extractor.extract() + + def get_scope(self) -> str: + return 'extractor.search_data' + + def _add_publish_tag_filter(self, publish_tag: str, cypher_query: str) -> str: + """ + Adds publish tag filter into Cypher query + :param publish_tag: value of publish tag. + :param cypher_query: + :return: + """ + + if not publish_tag: + publish_tag_filter = '' + else: + if not hasattr(self, 'entity'): + self.entity = 'table' + publish_tag_filter = f"WHERE {self.entity}.published_tag = '{publish_tag}'" + return cypher_query.format(publish_tag_filter=publish_tag_filter) diff --git a/databuilder/databuilder/extractor/neptune_search_data_extractor.py b/databuilder/databuilder/extractor/neptune_search_data_extractor.py new file mode 100644 index 0000000000..ef8a1edbf4 --- /dev/null +++ b/databuilder/databuilder/extractor/neptune_search_data_extractor.py @@ -0,0 +1,294 @@ +# Copyright Contributors to the Amundsen project. 
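# Illustrative configuration sketch for the extractor defined below; only the keys handled in
# this module are shown. The Neptune connection itself is configured under the
# NeptuneSessionClient scope, which is omitted here:
#
#     from pyhocon import ConfigFactory
#
#     neptune_search_conf = ConfigFactory.from_dict({
#         'entity_type': 'dashboard',       # selects _dashboard_search_query below
#         'job_publish_tag': 'unique_tag',  # hypothetical publish tag used as a filter
#     })
#
# Once initialized, each extract() call returns one result dict from the chosen Gremlin query
# function, or an instance of model_class if one is configured.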
+# SPDX-License-Identifier: Apache-2.0 + +import importlib +from typing import ( + Any, Dict, List, Optional, +) + +from gremlin_python.process.graph_traversal import GraphTraversalSource, __ +from gremlin_python.process.traversal import ( + Order, T, TextP, +) +from pyhocon import ConfigTree + +from databuilder import Scoped +from databuilder.clients.neptune_client import NeptuneSessionClient +from databuilder.extractor.base_extractor import Extractor +from databuilder.models.cluster.cluster_constants import CLUSTER_REVERSE_RELATION_TYPE +from databuilder.models.dashboard.dashboard_chart import DashboardChart +from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata +from databuilder.models.dashboard.dashboard_query import DashboardQuery +from databuilder.models.description_metadata import DescriptionMetadata +from databuilder.models.owner_constants import OWNER_OF_OBJECT_RELATION_TYPE +from databuilder.models.schema.schema_constant import SCHEMA_REVERSE_RELATION_TYPE +from databuilder.models.table_metadata import TableMetadata +from databuilder.models.timestamp.timestamp_constants import LASTUPDATED_RELATION_TYPE, TIMESTAMP_PROPERTY +from databuilder.models.usage.usage_constants import ( + READ_RELATION_COUNT_PROPERTY, READ_RELATION_TYPE, READ_REVERSE_RELATION_TYPE, +) +from databuilder.models.user import User +from databuilder.serializers.neptune_serializer import METADATA_KEY_PROPERTY_NAME + + +def _table_search_query(graph: GraphTraversalSource, tag_filter: str) -> List[Dict]: + traversal = graph.V().hasLabel(TableMetadata.TABLE_NODE_LABEL) + if tag_filter: + traversal = traversal.has('published_tag', tag_filter) + traversal = traversal.project( + 'database', + 'cluster', + 'schema', + 'schema_description', + 'name', + 'key', + 'description', + 'last_updated_timestamp', + 'column_names', + 'column_descriptions', + 'total_usage', + 'unique_usage', + 'tags', + 'badges', + 'programmatic_descriptions' + ) + traversal = traversal.by( + __.out( + TableMetadata.TABLE_SCHEMA_RELATION_TYPE + ).out(SCHEMA_REVERSE_RELATION_TYPE).out(CLUSTER_REVERSE_RELATION_TYPE).values('name') + ) # database + traversal = traversal.by( + __.out(TableMetadata.TABLE_SCHEMA_RELATION_TYPE).out(SCHEMA_REVERSE_RELATION_TYPE).values('name') + ) # cluster + traversal = traversal.by(__.out(TableMetadata.TABLE_SCHEMA_RELATION_TYPE).values('name')) # schema + traversal = traversal.by(__.coalesce( + __.out(TableMetadata.TABLE_SCHEMA_RELATION_TYPE).out( + DescriptionMetadata.DESCRIPTION_RELATION_TYPE + ).values('description'), + __.constant('') + )) # schema_description + traversal = traversal.by('name') # name + traversal = traversal.by(T.id) # key + traversal = traversal.by(__.coalesce( + __.out(DescriptionMetadata.DESCRIPTION_RELATION_TYPE).values('description'), + __.constant('') + )) # description + traversal = traversal.by( + __.coalesce(__.out(LASTUPDATED_RELATION_TYPE).values(TIMESTAMP_PROPERTY), __.constant('')) + ) # last_updated_timestamp + traversal = traversal.by(__.out(TableMetadata.TABLE_COL_RELATION_TYPE).values('name').fold()) # column_names + traversal = traversal.by( + __.out(TableMetadata.TABLE_COL_RELATION_TYPE).out( + DescriptionMetadata.DESCRIPTION_RELATION_TYPE + ).values('description').fold() + ) # column_descriptions + traversal = traversal.by(__.coalesce( + __.outE(READ_REVERSE_RELATION_TYPE).values('read_count'), + __.constant(0)).sum() + ) # total_usage + traversal = traversal.by(__.outE(READ_REVERSE_RELATION_TYPE).count()) # unique_usage + traversal = traversal.by( + 
__.inE(TableMetadata.TAG_TABLE_RELATION_TYPE).outV().values(METADATA_KEY_PROPERTY_NAME).fold() + ) # tags + traversal = traversal.by( + __.out('HAS_BADGE').values('keys').dedup().fold() + ) # badges + traversal = traversal.by( + __.out(DescriptionMetadata.PROGRAMMATIC_DESCRIPTION_NODE_LABEL).values('description').fold() + ) # programmatic_descriptions + traversal = traversal.order().by(__.select('name'), Order.asc) + return traversal.toList() + + +def _user_search_query(graph: GraphTraversalSource, tag_filter: str) -> List[Dict]: + traversal = graph.V().hasLabel(User.USER_NODE_LABEL) + traversal = traversal.has(User.USER_NODE_FULL_NAME) + if tag_filter: + traversal = traversal.where('published_tag', tag_filter) + traversal = traversal.project( + 'email', + 'first_name', + 'last_name', + 'full_name', + 'github_username', + 'team_name', + 'employee_type', + 'manager_email', + 'slack_id', + 'is_active', + 'role_name', + 'total_read', + 'total_own', + 'total_follow' + ) + traversal = traversal.by('email') # email + traversal = traversal.by('first_name') # first_name + traversal = traversal.by('last_name') # last_name + traversal = traversal.by('full_name') # full_name + traversal = traversal.by('github_username') # github_username + traversal = traversal.by('team_name') # team_name + traversal = traversal.by('employee_type') # employee_type + traversal = traversal.by(__.coalesce( + __.out(User.USER_MANAGER_RELATION_TYPE).values('email'), + __.constant('')) + ) # manager_email + traversal = traversal.by('slack_id') # slack_id + traversal = traversal.by('is_active') # is_active + traversal = traversal.by('role_name') # role_name + traversal = traversal.by(__.coalesce( + __.outE(READ_RELATION_TYPE).values('read_count'), + __.constant(0) + ).sum()) # total_read + traversal = traversal.by(__.outE(OWNER_OF_OBJECT_RELATION_TYPE).fold().count()) # total_own + traversal = traversal.by(__.outE('FOLLOWED_BY').fold().count()) # total_follow + traversal = traversal.order().by(__.select('email'), Order.asc) + return traversal.toList() + + +def _dashboard_search_query(graph: GraphTraversalSource, tag_filter: str) -> List[Dict]: + traversal = graph.V().hasLabel(DashboardMetadata.DASHBOARD_NODE_LABEL) + traversal = traversal.has('name') + if tag_filter: + traversal = traversal.where('published_tag', tag_filter) + + traversal = traversal.project( + 'group_name', + 'name', + 'cluster', + 'description', + 'group_description', + 'group_url', + 'url', + 'uri', + 'last_successful_run_timestamp', + 'query_names', + 'chart_names', + 'total_usage', + 'tags', + 'badges' + ) + traversal = traversal.by( + __.out(DashboardMetadata.DASHBOARD_DASHBOARD_GROUP_RELATION_TYPE).values('name') + ) # group_name + traversal = traversal.by('name') # name + traversal = traversal.by( + __.out( + DashboardMetadata.DASHBOARD_DASHBOARD_GROUP_RELATION_TYPE + ).out( + DashboardMetadata.DASHBOARD_GROUP_CLUSTER_RELATION_TYPE + ).values('name') + ) # cluster + traversal = traversal.by(__.coalesce( + __.out(DashboardMetadata.DASHBOARD_DESCRIPTION_RELATION_TYPE).values('description'), + __.constant('') + )) # description + traversal = traversal.by(__.coalesce( + __.out(DashboardMetadata.DASHBOARD_DASHBOARD_GROUP_RELATION_TYPE).out( + DashboardMetadata.DASHBOARD_DESCRIPTION_RELATION_TYPE + ).values('description'), + __.constant('') + )) # group_description + traversal = traversal.by( + __.out(DashboardMetadata.DASHBOARD_DASHBOARD_GROUP_RELATION_TYPE).values('dashboard_group_url') + ) # group_url + traversal = traversal.by('dashboard_url') # 
dashboard_url + traversal = traversal.by('key') # uri + + traversal = traversal.by( + __.coalesce( + __.out('EXECUTED').has('key', TextP.endingWith('_last_successful_execution')).values('timestamp'), + __.constant('') + ) + ) # last_successful_run_timestamp + traversal = traversal.by( + __.out(DashboardQuery.DASHBOARD_QUERY_RELATION_TYPE).values('name').dedup().fold() + ) # query_names + traversal = traversal.by( + __.out( + DashboardQuery.DASHBOARD_QUERY_RELATION_TYPE + ).out(DashboardChart.CHART_RELATION_TYPE).values('name').dedup().fold() + ) # chart_names + traversal = traversal.by(__.coalesce( + __.outE(READ_REVERSE_RELATION_TYPE).values(READ_RELATION_COUNT_PROPERTY), + __.constant(0) + ).sum()) # total_usage + traversal = traversal.by( + __.out('TAGGED_BY').has('tag_type', 'default').values('keys').dedup().fold() + ) # tags + traversal = traversal.by( + __.out('HAS_BADGE').values('keys').dedup().fold() + ) # badges + + traversal = traversal.order().by(__.select('name'), Order.asc) + + dashboards = traversal.toList() + for dashboard in dashboards: + dashboard['product'] = dashboard['uri'].split('_')[0] + + return dashboards + + +class NeptuneSearchDataExtractor(Extractor): + """ + Extractor to fetch data required to support search from Neptune's graph database + """ + QUERY_FUNCTION_CONFIG_KEY = 'query_function' + QUERY_FUNCTION_KWARGS_CONFIG_KEY = 'query_function_kwargs' + ENTITY_TYPE_CONFIG_KEY = 'entity_type' + JOB_PUBLISH_TAG_CONFIG_KEY = 'job_publish_tag' + MODEL_CLASS_CONFIG_KEY = 'model_class' + + DEFAULT_QUERY_BY_ENTITY = { + 'table': _table_search_query, + 'user': _user_search_query, + 'dashboard': _dashboard_search_query + } + + def init(self, conf: ConfigTree) -> None: + self.conf = conf + self.entity = conf.get_string(NeptuneSearchDataExtractor.ENTITY_TYPE_CONFIG_KEY, default='table').lower() + + if NeptuneSearchDataExtractor.QUERY_FUNCTION_CONFIG_KEY in conf: + self.query_function = conf.get(NeptuneSearchDataExtractor.QUERY_FUNCTION_CONFIG_KEY) + else: + self.query_function = NeptuneSearchDataExtractor.DEFAULT_QUERY_BY_ENTITY[self.entity] + + self.job_publish_tag = conf.get_string(NeptuneSearchDataExtractor.JOB_PUBLISH_TAG_CONFIG_KEY, '') + self.neptune_client = NeptuneSessionClient() + + neptune_client_conf = Scoped.get_scoped_conf(conf, self.neptune_client.get_scope()) + self.neptune_client.init(neptune_client_conf) + + model_class = conf.get(NeptuneSearchDataExtractor.MODEL_CLASS_CONFIG_KEY, None) + if model_class: + module_name, class_name = model_class.rsplit(".", 1) + mod = importlib.import_module(module_name) + self.model_class = getattr(mod, class_name) + + self._extract_iter: Optional[Any] = None + + def close(self) -> None: + self.neptune_client.close() + + def extract(self) -> Optional[Any]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + + try: + return next(self._extract_iter) + except StopIteration: + return None + + def _get_extract_iter(self) -> Any: + if not hasattr(self, 'results'): + self.results = self.query_function(self.neptune_client.get_graph(), tag_filter=self.job_publish_tag) + + for result in self.results: + if hasattr(self, 'model_class'): + obj = self.model_class(**result) + yield obj + else: + yield result + + def get_scope(self) -> str: + return 'extractor.neptune_search_data' diff --git a/databuilder/databuilder/extractor/openlineage_extractor.py b/databuilder/databuilder/extractor/openlineage_extractor.py new file mode 100644 index 0000000000..cd3323b758 --- /dev/null +++ 
b/databuilder/databuilder/extractor/openlineage_extractor.py @@ -0,0 +1,106 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import json +import logging +from typing import ( + Any, Dict, Iterator, +) + +from pyhocon import ConfigTree + +from databuilder.extractor.base_extractor import Extractor +from databuilder.models.table_lineage import TableLineage + +LOGGER = logging.getLogger(__name__) + + +class OpenLineageTableLineageExtractor(Extractor): + # Config keys + TABLE_LINEAGE_FILE_LOCATION = 'table_lineage_file_location' + CLUSTER_NAME = 'cluster_name' + OL_DATASET_NAMESPACE_OVERRIDE = 'namespace_override' + # Openlineage values key's, which will be used to extract data from an OpenLineage event + OL_INPUTS_KEY = 'inputs_key' + OL_OUTPUTS_KEY = 'outputs_key' + OL_DATASET_NAMESPACE_KEY = 'namespace_key' + OL_DATASET_DATABASE_KEY = 'database_key' + OL_DATASET_NAME_KEY = 'dataset_name_key' + + """ + An Extractor that creates Table Lineage between two tables based on OpenLineage event + """ + + def init(self, conf: ConfigTree) -> None: + """ + :param conf: + """ + self.conf = conf + self.table_lineage_file_location = conf.get_string(OpenLineageTableLineageExtractor.TABLE_LINEAGE_FILE_LOCATION) + self.cluster_name = conf.get_string(OpenLineageTableLineageExtractor.CLUSTER_NAME) + self.ol_inputs_key = conf.get_string(OpenLineageTableLineageExtractor.OL_INPUTS_KEY, default='inputs') + self.ol_outputs_key = conf.get_string(OpenLineageTableLineageExtractor.OL_OUTPUTS_KEY, default='outputs') + self.ol_namespace_key = conf.get_string( + OpenLineageTableLineageExtractor.OL_DATASET_NAMESPACE_KEY, default='namespace') + self.ol_database_key = conf.get_string( + OpenLineageTableLineageExtractor.OL_DATASET_DATABASE_KEY, default='database') + self.ol_dataset_name_key = conf.get_string( + OpenLineageTableLineageExtractor.OL_DATASET_NAME_KEY, default='name') + self.ol_namespace_override = conf.get_string( + OpenLineageTableLineageExtractor.OL_DATASET_NAMESPACE_OVERRIDE, default=None) + self._load_openlineage_event() + + def _extract_dataset_info(self, openlineage_event: Any) -> Iterator[Dict]: + """ + Yield input/output dict in form of amundsen table keys + """ + + for event in openlineage_event: + try: + in_and_outs = ((inputs, outputs) + for inputs in event[self.ol_inputs_key] + for outputs in event[self.ol_outputs_key]) + for row in in_and_outs: + yield {'input': self._amundsen_dataset_key(row[0]), + 'output': self._amundsen_dataset_key(row[1])} + except KeyError: + LOGGER.error(f'Cannot extract valid input or output from Openlineage event \n {event} ') + + def _amundsen_dataset_key(self, dataset: Dict) -> str: + """ + Generation of amundsen dataset key with optional namespace overriding. + Amundsen dataset key format: ://./
. + If dataset name is represented in path form ie. ( /warehouse/database/table ) + only last part of such path will be extracted as dataset name + """ + namespace = self.ol_namespace_override if self.ol_namespace_override else dataset[self.ol_namespace_key] + return f'{namespace}://{self.cluster_name}.{dataset[self.ol_database_key]}' \ + f'/{dataset[self.ol_dataset_name_key].split("/")[-1]}' + + def _load_openlineage_event(self) -> Any: + + self.input_file = open(self.table_lineage_file_location, 'r') + + lineage_event = (json.loads(line) for line in self.input_file) + + table_lineage = (TableLineage(table_key=lineage['input'], + downstream_deps=[lineage['output']]) + + for lineage in self._extract_dataset_info(lineage_event)) + self._iter = table_lineage + + def extract(self) -> Any: + """ + Yield the csv result one at a time. + convert the result to model if a model_class is provided + """ + try: + return next(self._iter) + except StopIteration: + self.input_file.close() + return None + except Exception as e: + raise e + + def get_scope(self) -> str: + return 'extractor.openlineage_tablelineage' diff --git a/databuilder/databuilder/extractor/oracle_metadata_extractor.py b/databuilder/databuilder/extractor/oracle_metadata_extractor.py new file mode 100644 index 0000000000..a6d6099272 --- /dev/null +++ b/databuilder/databuilder/extractor/oracle_metadata_extractor.py @@ -0,0 +1,136 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections import namedtuple +from itertools import groupby +from typing import ( + Any, Dict, Iterator, Union, +) + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +TableKey = namedtuple('TableKey', ['schema', 'table_name']) + +LOGGER = logging.getLogger(__name__) + + +class OracleMetadataExtractor(Extractor): + """ + Extracts Oracle table and column metadata from underlying meta store database using SQLAlchemyExtractor + """ + + # CONFIG KEYS + WHERE_CLAUSE_SUFFIX_KEY = 'where_clause_suffix' + CLUSTER_KEY = 'cluster_key' + DATABASE_KEY = 'database_key' + + # Default values + DEFAULT_CLUSTER_NAME = 'master' + + DEFAULT_CONFIG = ConfigFactory.from_dict( + {WHERE_CLAUSE_SUFFIX_KEY: ' ', CLUSTER_KEY: DEFAULT_CLUSTER_NAME} + ) + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(OracleMetadataExtractor.DEFAULT_CONFIG) + self._cluster = conf.get_string(OracleMetadataExtractor.CLUSTER_KEY, default='oracle') + + self._database = conf.get_string(OracleMetadataExtractor.DATABASE_KEY, default='oracle') + + self.sql_stmt = self.get_sql_statement( + where_clause_suffix=conf.get_string(OracleMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY), + ) + + self._alchemy_extractor = SQLAlchemyExtractor() + sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope()) \ + .with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt})) + + self.sql_stmt = sql_alch_conf.get_string(SQLAlchemyExtractor.EXTRACT_SQL) + + LOGGER.info('SQL for oracle metadata: %s', self.sql_stmt) + + self._alchemy_extractor.init(sql_alch_conf) + self._extract_iter: Union[None, Iterator] = None + + def extract(self) -> Union[TableMetadata, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return 
next(self._extract_iter) + except StopIteration: + return None + + def _get_extract_iter(self) -> Iterator[TableMetadata]: + """ + Using itertools.groupby and raw level iterator, it groups to table and yields TableMetadata + :return: + """ + for key, group in groupby(self._get_raw_extract_iter(), self._get_table_key): + columns = [] + + for row in group: + last_row = row + columns.append(ColumnMetadata(row['col_name'], row['col_description'], + row['col_type'], row['col_sort_order'])) + + yield TableMetadata(self._database, last_row['cluster'], + last_row['schema'], + last_row['name'], + last_row['description'], + columns) + + def _get_raw_extract_iter(self) -> Iterator[Dict[str, Any]]: + """ + Provides iterator of result row from SQLAlchemy extractor + :return: + """ + row = self._alchemy_extractor.extract() + while row: + yield row + row = self._alchemy_extractor.extract() + + def _get_table_key(self, row: Dict[str, Any]) -> Union[TableKey, None]: + """ + Table key consists of schema and table name + :param row: + :return: + """ + if row: + return TableKey(schema=row['schema'], table_name=row['name']) + + return None + + def get_sql_statement(self, where_clause_suffix: str) -> str: + cluster_source = f"'{self._cluster}'" + + return """ + SELECT + {cluster_source} as "cluster", + lower(c.owner) as "schema", + lower(c.table_name) as "name", + tc.comments as "description", + lower(c.column_name) as "col_name", + lower(c.data_type) as "col_type", + cc.comments as "col_description", + lower(c.column_id) as "col_sort_order" + FROM + all_tab_columns c + LEFT JOIN + all_tab_comments tc ON c.owner=tc.owner AND c.table_name=tc.table_name + LEFT JOIN + all_col_comments cc ON c.owner=cc.owner AND c.table_name=cc.table_name AND c.column_name=cc.column_name + {where_clause_suffix} + ORDER BY "cluster", "schema", "name", "col_sort_order" + """.format( + cluster_source=cluster_source, + where_clause_suffix=where_clause_suffix, + ) + + def get_scope(self) -> str: + return 'extractor.oracle_metadata' diff --git a/databuilder/databuilder/extractor/pandas_profiling_column_stats_extractor.py b/databuilder/databuilder/extractor/pandas_profiling_column_stats_extractor.py new file mode 100644 index 0000000000..1b8dd8d76d --- /dev/null +++ b/databuilder/databuilder/extractor/pandas_profiling_column_stats_extractor.py @@ -0,0 +1,193 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import json +from typing import ( + Any, Dict, Tuple, +) + +import dateutil.parser +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor.base_extractor import Extractor +from databuilder.models.table_stats import TableColumnStats + + +class PandasProfilingColumnStatsExtractor(Extractor): + FILE_PATH = 'file_path' + DATABASE_NAME = 'database_name' + TABLE_NAME = 'table_name' + SCHEMA_NAME = 'schema_name' + CLUSTER_NAME = 'cluster_name' + + # if you wish to collect only selected set of metrics configure stat_mappings option of the extractor providing + # similar dictionary but containing only keys of metrics you wish to collect. + # For example - if you want only min and max value of a column, provide extractor with configuration option: + # PandasProfilingColumnStatsExtractor.STAT_MAPPINGS = {'max': ('Maximum', float), 'min': ('Minimum', float)} + STAT_MAPPINGS = 'stat_mappings' + + # - key - raw name of the stat in pandas-profiling. Value - tuple of stat spec. 
+ # - first value of the tuple - full name of the stat + # - second value of the tuple - function modifying the stat (by default we just do type casting) + DEFAULT_STAT_MAPPINGS = { + '25%': ('Quantile 25%', float), + '5%': ('Quantile 5%', float), + '50%': ('Quantile 50%', float), + '75%': ('Quantile 75%', float), + '95%': ('Quantile 95%', float), + 'chi_squared': ('Chi squared', lambda x: float(x.get('statistic'))), + 'count': ('Count', int), + 'is_unique': ('Unique', bool), + 'kurtosis': ('Kurtosis', float), + 'max': ('Maximum', str), + 'max_length': ('Maximum length', int), + 'mean': ('Mean', float), + 'mean_length': ('Mean length', int), + 'median_length': ('Median length', int), + 'min': ('Minimum', str), + 'min_length': ('Minimum length', int), + 'monotonic': ('Monotonic', bool), + 'n_characters': ('Characters', int), + 'n_characters_distinct': ('Distinct characters', int), + 'n_distinct': ('Distinct values', int), + 'n_infinite': ('Infinite values', int), + 'n_missing': ('Missing values', int), + 'n_negative': ('Negative values', int), + 'n_unique': ('Unique values', int), + 'n_zeros': ('Zeros', int), + 'p_distinct': ('Distinct values %', lambda x: float(x * 100)), + 'p_infinite': ('Infinite values %', lambda x: float(x * 100)), + 'p_missing': ('Missing values %', lambda x: float(x * 100)), + 'p_negative': ('Negative values %', lambda x: float(x * 100)), + 'p_unique': ('Unique values %', lambda x: float(x * 100)), + 'p_zeros': ('Zeros %', lambda x: float(x * 100)), + 'range': ('Range', str), + 'skewness': ('Skewness', float), + 'std': ('Std. deviation', float), + 'sum': ('Sum', float), + 'variance': ('Variance', float) + # Stats available in pandas-profiling but are not collected by default and require custom, conscious config.. + # 'block_alias_char_counts': ('',), + # 'block_alias_counts': ('',), + # 'block_alias_values': ('',), + # 'category_alias_char_counts': ('',), + # 'category_alias_counts': ('',), + # 'category_alias_values': ('',), + # 'character_counts': ('',), + # 'cv': ('',), + # 'first_rows': ('',), + # 'hashable': ('',), + # 'histogram': ('',), + # 'histogram_frequencies': ('',), + # 'histogram_length': ('',), + # 'iqr': ('',), + # 'length': ('',), + # 'mad': ('',), + # 'memory_size': ('',), + # 'monotonic_decrease': ('Monotonic decrease', bool), + # 'monotonic_decrease_strict': ('Strict monotonic decrease', bool), + # 'monotonic_increase': ('Monotonic increase', bool), + # 'monotonic_increase_strict': ('Strict monotonic increase', bool), + # 'n': ('',), + # 'n_block_alias': ('',), + # 'n_category': ('Categories', int), + # 'n_scripts': ('',), + # 'ordering': ('',), + # 'script_char_counts': ('',), + # 'script_counts': ('',), + # 'value_counts_index_sorted': ('',), + # 'value_counts_without_nan': ('',), + # 'word_counts': ('',), + # 'type': ('Type', str) + } + + PRECISION = 'precision' + + DEFAULT_CONFIG = ConfigFactory.from_dict({STAT_MAPPINGS: DEFAULT_STAT_MAPPINGS, PRECISION: 3}) + + def get_scope(self) -> str: + return 'extractor.pandas_profiling' + + def init(self, conf: ConfigTree) -> None: + self.conf = conf.with_fallback(PandasProfilingColumnStatsExtractor.DEFAULT_CONFIG) + + self._extract_iter = self._get_extract_iter() + + def extract(self) -> Any: + try: + result = next(self._extract_iter) + + return result + except StopIteration: + return None + + def _get_extract_iter(self) -> Any: + report = self._load_report() + + variables = report.get('variables', dict()) + report_time = self.parse_date(report.get('analysis', dict()).get('date_start')) + + for 
column_name, column_stats in variables.items(): + for _stat_name, stat_value in column_stats.items(): + stat_spec = self.stat_mappings.get(_stat_name) + + if stat_spec: + stat_name, stat_modifier = stat_spec + + if isinstance(stat_value, float): + stat_value = self.round_value(stat_value) + + stat = TableColumnStats(table_name=self.table_name, col_name=column_name, stat_name=stat_name, + stat_val=stat_modifier(stat_value), start_epoch=report_time, end_epoch='0', + db=self.database_name, cluster=self.cluster_name, schema=self.schema_name) + + yield stat + + def _load_report(self) -> Dict[str, Any]: + path = self.conf.get(PandasProfilingColumnStatsExtractor.FILE_PATH) + + try: + with open(path, 'r') as f: + _data = f.read() + + data = json.loads(_data) + + return data + except Exception: + return {} + + @staticmethod + def parse_date(string_date: str) -> str: + try: + date_parsed = dateutil.parser.parse(string_date) + + # date from pandas-profiling doesn't contain timezone so to be timezone safe we need to assume it's utc + if not date_parsed.tzname(): + return PandasProfilingColumnStatsExtractor.parse_date(f'{string_date}+0000') + + return str(int(date_parsed.timestamp())) + except Exception: + return '0' + + def round_value(self, value: float) -> float: + return round(value, self.conf.get(PandasProfilingColumnStatsExtractor.PRECISION)) + + @property + def stat_mappings(self) -> Dict[str, Tuple[str, Any]]: + return dict(self.conf.get(PandasProfilingColumnStatsExtractor.STAT_MAPPINGS)) + + @property + def cluster_name(self) -> str: + return self.conf.get(PandasProfilingColumnStatsExtractor.CLUSTER_NAME) + + @property + def database_name(self) -> str: + return self.conf.get(PandasProfilingColumnStatsExtractor.DATABASE_NAME) + + @property + def schema_name(self) -> str: + return self.conf.get(PandasProfilingColumnStatsExtractor.SCHEMA_NAME) + + @property + def table_name(self) -> str: + return self.conf.get(PandasProfilingColumnStatsExtractor.TABLE_NAME) diff --git a/databuilder/databuilder/extractor/postgres_metadata_extractor.py b/databuilder/databuilder/extractor/postgres_metadata_extractor.py new file mode 100644 index 0000000000..cb80565757 --- /dev/null +++ b/databuilder/databuilder/extractor/postgres_metadata_extractor.py @@ -0,0 +1,55 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from typing import ( # noqa: F401 + Any, Dict, Iterator, Union, +) + +from pyhocon import ConfigFactory, ConfigTree # noqa: F401 + +from databuilder.extractor.base_postgres_metadata_extractor import BasePostgresMetadataExtractor + + +class PostgresMetadataExtractor(BasePostgresMetadataExtractor): + """ + Extracts Postgres table and column metadata from underlying meta store database using SQLAlchemyExtractor + """ + + def get_sql_statement(self, use_catalog_as_cluster_name: bool, where_clause_suffix: str) -> str: + if use_catalog_as_cluster_name: + cluster_source = "current_database()" + else: + cluster_source = f"'{self._cluster}'" + + return """ + SELECT + {cluster_source} as cluster, + st.schemaname as schema, + st.relname as name, + pgtd.description as description, + att.attname as col_name, + pgtyp.typname as col_type, + pgcd.description as col_description, + att.attnum as col_sort_order + FROM pg_catalog.pg_attribute att + INNER JOIN + pg_catalog.pg_statio_all_tables as st + on att.attrelid=st.relid + LEFT JOIN + pg_catalog.pg_type pgtyp + on pgtyp.oid=att.atttypid + LEFT JOIN + pg_catalog.pg_description pgtd + on pgtd.objoid=st.relid and pgtd.objsubid=0 + LEFT JOIN + pg_catalog.pg_description pgcd + on pgcd.objoid=st.relid and pgcd.objsubid=att.attnum + WHERE att.attnum >=0 and {where_clause_suffix} + ORDER by cluster, schema, name, col_sort_order; + """.format( + cluster_source=cluster_source, + where_clause_suffix=where_clause_suffix, + ) + + def get_scope(self) -> str: + return 'extractor.postgres_metadata' diff --git a/databuilder/databuilder/extractor/presto_view_metadata_extractor.py b/databuilder/databuilder/extractor/presto_view_metadata_extractor.py new file mode 100644 index 0000000000..e328139258 --- /dev/null +++ b/databuilder/databuilder/extractor/presto_view_metadata_extractor.py @@ -0,0 +1,135 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import base64 +import json +import logging +from typing import ( + Iterator, List, Union, +) + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor import sql_alchemy_extractor +from databuilder.extractor.base_extractor import Extractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +LOGGER = logging.getLogger(__name__) + + +class PrestoViewMetadataExtractor(Extractor): + """ + Extracts Presto View and column metadata from underlying meta store database using SQLAlchemyExtractor + PrestoViewMetadataExtractor does not require a separate table model but just reuse the existing TableMetadata + """ + # SQL statement to extract View metadata + # {where_clause_suffix} could be used to filter schemas + DEFAULT_SQL_STATEMENT = """ + SELECT t.TBL_ID, d.NAME as `schema`, t.TBL_NAME name, t.TBL_TYPE, t.VIEW_ORIGINAL_TEXT as view_original_text + FROM TBLS t + JOIN DBS d ON t.DB_ID = d.DB_ID + WHERE t.VIEW_EXPANDED_TEXT = '/* Presto View */' + {where_clause_suffix} + ORDER BY t.TBL_ID desc; + """ + + DEFAULT_POSTGRES_SQL_STATEMENT = """ + SELECT t."TBL_ID", + d."NAME" as "schema", + t."TBL_NAME" as name, + t."TBL_TYPE", + t."VIEW_ORIGINAL_TEXT" as view_original_text + FROM "TBLS" t + JOIN "DBS" d ON t."DB_ID" = d."DB_ID" + WHERE t."VIEW_EXPANDED_TEXT" = '/* Presto View */' + {where_clause_suffix} + ORDER BY t."TBL_ID" desc; + """ + + # Presto View data prefix and suffix definition: + # https://github.com/prestodb/presto/blob/43bd519052ba4c56ff1f4fc807075637ab5f4f10/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java#L153-L154 + PRESTO_VIEW_PREFIX = '/* Presto View: ' + PRESTO_VIEW_SUFFIX = ' */' + + # CONFIG KEYS + WHERE_CLAUSE_SUFFIX_KEY = 'where_clause_suffix' + CLUSTER_KEY = 'cluster' + + DEFAULT_CONFIG = ConfigFactory.from_dict({WHERE_CLAUSE_SUFFIX_KEY: ' ', + CLUSTER_KEY: 'gold'}) + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(PrestoViewMetadataExtractor.DEFAULT_CONFIG) + self._cluster = conf.get_string(PrestoViewMetadataExtractor.CLUSTER_KEY) + + self.sql_stmt = self._choose_default_sql_stm(conf).format( + where_clause_suffix=conf.get_string(PrestoViewMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY)) + + LOGGER.info('SQL for hive metastore: %s', self.sql_stmt) + + self._alchemy_extractor = sql_alchemy_extractor.from_surrounding_config(conf, self.sql_stmt) + self._extract_iter: Union[None, Iterator] = None + + def _choose_default_sql_stm(self, conf: ConfigTree) -> str: + conn_string = conf.get_string("extractor.sqlalchemy.conn_string") + if conn_string.startswith('postgres') or conn_string.startswith('postgresql'): + return self.DEFAULT_POSTGRES_SQL_STATEMENT + else: + return self.DEFAULT_SQL_STATEMENT + + def close(self) -> None: + if getattr(self, '_alchemy_extractor', None) is not None: + self._alchemy_extractor.close() + + def extract(self) -> Union[TableMetadata, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.presto_view_metadata' + + def _get_extract_iter(self) -> Iterator[TableMetadata]: + """ + Using itertools.groupby and raw level iterator, it groups to table and yields TableMetadata + :return: + """ + row = self._alchemy_extractor.extract() + while row: + columns = self._get_column_metadata(row['view_original_text']) + yield TableMetadata(database='presto', + cluster=self._cluster, + 
schema=row['schema'], + name=row['name'], + description=None, + columns=columns, + is_view=True) + row = self._alchemy_extractor.extract() + + def _get_column_metadata(self, + view_original_text: str) -> List[ColumnMetadata]: + """ + Get Column Metadata from VIEW_ORIGINAL_TEXT from TBLS table for Presto Views. + Columns are sorted the same way as they appear in Presto Create View SQL. + :param view_original_text: + :return: + """ + # remove encoded Presto View data prefix and suffix + encoded_view_info = ( + view_original_text. + split(PrestoViewMetadataExtractor.PRESTO_VIEW_PREFIX, 1)[-1]. + rsplit(PrestoViewMetadataExtractor.PRESTO_VIEW_SUFFIX, 1)[0] + ) + + # view_original_text is b64 encoded: + # https://github.com/prestodb/presto/blob/43bd519052ba4c56ff1f4fc807075637ab5f4f10/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java#L602-L605 + decoded_view_info = base64.b64decode(encoded_view_info) + columns = json.loads(decoded_view_info).get('columns') + + return [ColumnMetadata(name=column['name'], + description=None, + col_type=column['type'], + sort_order=i) for i, column in enumerate(columns)] diff --git a/databuilder/databuilder/extractor/redshift_metadata_extractor.py b/databuilder/databuilder/extractor/redshift_metadata_extractor.py new file mode 100644 index 0000000000..78cc46bff7 --- /dev/null +++ b/databuilder/databuilder/extractor/redshift_metadata_extractor.py @@ -0,0 +1,98 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import ( # noqa: F401 + Any, Dict, Iterator, Union, +) + +from pyhocon import ConfigFactory, ConfigTree # noqa: F401 + +from databuilder.extractor.base_postgres_metadata_extractor import BasePostgresMetadataExtractor + +LOGGER = logging.getLogger(__name__) + + +class RedshiftMetadataExtractor(BasePostgresMetadataExtractor): + """ + Extracts Redshift table and column metadata from underlying meta store database using SQLAlchemyExtractor + + + This differs from the PostgresMetadataExtractor because in order to support Redshift's late binding views, + we need to join the INFORMATION_SCHEMA data against the function PG_GET_LATE_BINDING_VIEW_COLS(). 
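+ + A minimal, illustrative configuration sketch (the key paths are an assumption based on this + extractor's scope, 'extractor.redshift_metadata', and the conventional nesting of the wrapped + SQLAlchemyExtractor; verify them against your own job setup before relying on them): + + 'extractor.redshift_metadata.where_clause_suffix': "schema not in ('pg_catalog')", + 'extractor.redshift_metadata.extractor.sqlalchemy.conn_string': 'postgresql://user:password@host:5439/dev' 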
+ """ + + def get_sql_statement(self, use_catalog_as_cluster_name: bool, where_clause_suffix: str) -> str: + if use_catalog_as_cluster_name: + cluster_source = "CURRENT_DATABASE()" + else: + cluster_source = f"'{self._cluster}'" + + if where_clause_suffix: + if where_clause_suffix.lower().startswith("where"): + LOGGER.warning("you no longer need to begin with 'where' in your suffix") + where_clause = where_clause_suffix + else: + where_clause = f"where {where_clause_suffix}" + else: + where_clause = "" + + return """ + SELECT + * + FROM ( + SELECT + {cluster_source} as cluster, + c.table_schema as schema, + c.table_name as name, + pgtd.description as description, + c.column_name as col_name, + c.data_type as col_type, + pgcd.description as col_description, + ordinal_position as col_sort_order + FROM INFORMATION_SCHEMA.COLUMNS c + INNER JOIN + pg_catalog.pg_statio_all_tables as st on c.table_schema=st.schemaname and c.table_name=st.relname + LEFT JOIN + pg_catalog.pg_description pgcd on pgcd.objoid=st.relid and pgcd.objsubid=c.ordinal_position + LEFT JOIN + pg_catalog.pg_description pgtd on pgtd.objoid=st.relid and pgtd.objsubid=0 + + UNION + + SELECT + {cluster_source} as cluster, + view_schema as schema, + view_name as name, + NULL as description, + column_name as col_name, + data_type as col_type, + NULL as col_description, + ordinal_position as col_sort_order + FROM + PG_GET_LATE_BINDING_VIEW_COLS() + COLS(view_schema NAME, view_name NAME, column_name NAME, data_type VARCHAR, ordinal_position INT) + + UNION + + SELECT + {cluster_source} AS cluster, + schemaname AS schema, + tablename AS name, + NULL AS description, + columnname AS col_name, + external_type AS col_type, + NULL AS col_description, + columnnum AS col_sort_order + FROM svv_external_columns + ) + + {where_clause} + ORDER by cluster, schema, name, col_sort_order ; + """.format( + cluster_source=cluster_source, + where_clause=where_clause, + ) + + def get_scope(self) -> str: + return 'extractor.redshift_metadata' diff --git a/databuilder/databuilder/extractor/restapi/__init__.py b/databuilder/databuilder/extractor/restapi/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/extractor/restapi/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/extractor/restapi/rest_api_extractor.py b/databuilder/databuilder/extractor/restapi/rest_api_extractor.py new file mode 100644 index 0000000000..65c22dc9f0 --- /dev/null +++ b/databuilder/databuilder/extractor/restapi/rest_api_extractor.py @@ -0,0 +1,70 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import importlib +import logging +from typing import ( + Any, Dict, Iterator, Optional, +) + +from pyhocon import ConfigTree + +from databuilder.extractor.base_extractor import Extractor +from databuilder.rest_api.base_rest_api_query import BaseRestApiQuery + +REST_API_QUERY = 'restapi_query' +MODEL_CLASS = 'model_class' + +# Static record that will be added into extracted record +# For example, DashboardMetadata requires product name (static name) of Dashboard and REST api does not provide +# it. and you can add {'product': 'mode'} so that it will be included in the record. +STATIC_RECORD_DICT = 'static_record_dict' + +LOGGER = logging.getLogger(__name__) + + +class RestAPIExtractor(Extractor): + """ + An Extractor that calls one or more REST API to extract the data. 
+ This extractor almost entirely depends on RestApiQuery. + """ + + def init(self, conf: ConfigTree) -> None: + + self._restapi_query: BaseRestApiQuery = conf.get(REST_API_QUERY) + self._iterator: Optional[Iterator[Dict[str, Any]]] = None + self._static_dict = conf.get(STATIC_RECORD_DICT, dict()) + LOGGER.info('static record: %s', self._static_dict) + + model_class = conf.get(MODEL_CLASS, None) + if model_class: + module_name, class_name = model_class.rsplit(".", 1) + mod = importlib.import_module(module_name) + self.model_class = getattr(mod, class_name) + + def extract(self) -> Any: + """ + Fetch one result row from RestApiQuery, convert to {model_class} if specified before + returning. + :return: + """ + + if not self._iterator: + self._iterator = self._restapi_query.execute() + + try: + record = next(self._iterator) + except StopIteration: + return None + + if self._static_dict: + record.update(self._static_dict) + + if hasattr(self, 'model_class'): + return self.model_class(**record) + + return record + + def get_scope(self) -> str: + + return 'extractor.restapi' diff --git a/databuilder/databuilder/extractor/salesforce_extractor.py b/databuilder/databuilder/extractor/salesforce_extractor.py new file mode 100644 index 0000000000..de258cdff9 --- /dev/null +++ b/databuilder/databuilder/extractor/salesforce_extractor.py @@ -0,0 +1,102 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import ( + Any, Dict, Iterator, List, Union, +) + +from pyhocon import ConfigTree +from simple_salesforce import Salesforce + +from databuilder.extractor.base_extractor import Extractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +LOGGER = logging.getLogger(__name__) + + +class SalesForceExtractor(Extractor): + """ + Extracts SalesForce objects + """ + + # CONFIG KEYS + CLUSTER_KEY = 'cluster_key' + SCHEMA_KEY = 'schema_key' + DATABASE_KEY = 'database_key' + OBJECT_NAMES_KEY = "object_names" + USERNAME_KEY = "username" + PASSWORD_KEY = "password" + SECURITY_TOKEN_KEY = "security_token" + + def init(self, conf: ConfigTree) -> None: + + self._cluster: str = conf.get_string(SalesForceExtractor.CLUSTER_KEY, "gold") + self._database: str = conf.get_string(SalesForceExtractor.DATABASE_KEY) + self._schema: str = conf.get_string(SalesForceExtractor.SCHEMA_KEY) + self._object_names: List[str] = conf.get_list(SalesForceExtractor.OBJECT_NAMES_KEY, []) + + self._client: Salesforce = Salesforce( + username=conf.get_string(SalesForceExtractor.USERNAME_KEY), + password=conf.get_string(SalesForceExtractor.PASSWORD_KEY), + security_token=conf.get_string(SalesForceExtractor.SECURITY_TOKEN_KEY), + ) + + self._extract_iter: Union[None, Iterator] = None + + def extract(self) -> Union[TableMetadata, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def _get_extract_iter(self) -> Iterator[TableMetadata]: + """ + Extract the TableMetaData for each SalesForce Object + :return: + """ + + # Filter the sobjects if `OBJECT_NAMES_KEY` is set otherwise return all + sobjects = [ + sobject + for sobject in self._client.describe()["sobjects"] + if (len(self._object_names) == 0 or sobject["name"] in self._object_names) + ] + + for i, sobject in enumerate(sobjects): + object_name = sobject["name"] + logging.info( + f"({i+1}/{len(sobjects)}) Extracting SalesForce object ({object_name})" + ) + data = 
self._client.restful(path=f"sobjects/{object_name}/describe") + yield self._extract_table_metadata(object_name=object_name, data=data) + + def _extract_table_metadata( + self, object_name: str, data: Dict[str, Any] + ) -> TableMetadata: + # sort the fields by name because Amundsen requires a sort order for the columns and I did + # not see one in the response + fields = sorted(data["fields"], key=lambda x: x["name"]) + columns = [ + ColumnMetadata( + name=f["name"], + description=f["inlineHelpText"], + col_type=f["type"], + sort_order=i, + ) + for i, f in enumerate(fields) + ] + return TableMetadata( + database=self._database, + cluster=self._cluster, + schema=self._schema, + name=object_name, + # TODO: Can we extract table description / does it exist? + description=None, + columns=columns, + ) + + def get_scope(self) -> str: + return 'extractor.salesforce_metadata' diff --git a/databuilder/databuilder/extractor/snowflake_metadata_extractor.py b/databuilder/databuilder/extractor/snowflake_metadata_extractor.py new file mode 100644 index 0000000000..983feebf0b --- /dev/null +++ b/databuilder/databuilder/extractor/snowflake_metadata_extractor.py @@ -0,0 +1,162 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from collections import namedtuple +from itertools import groupby +from typing import ( + Any, Dict, Iterator, Union, +) + +from pyhocon import ConfigFactory, ConfigTree +from text_unidecode import unidecode + +from databuilder.extractor import sql_alchemy_extractor +from databuilder.extractor.base_extractor import Extractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +TableKey = namedtuple('TableKey', ['schema', 'table_name']) + +LOGGER = logging.getLogger(__name__) + + +class SnowflakeMetadataExtractor(Extractor): + """ + Extracts Snowflake table and column metadata from underlying meta store database using SQLAlchemyExtractor. + Requirements: + snowflake-connector-python + snowflake-sqlalchemy + """ + # SELECT statement from snowflake information_schema to extract table and column metadata + # https://docs.snowflake.com/en/sql-reference/account-usage.html#label-account-usage-views + # This can be modified to use account_usage for performance at the cost of latency if necessary. + SQL_STATEMENT = """ + SELECT + lower(c.column_name) AS col_name, + c.comment AS col_description, + lower(c.data_type) AS col_type, + lower(c.ordinal_position) AS col_sort_order, + lower(c.table_catalog) AS database, + lower({cluster_source}) AS cluster, + lower(c.table_schema) AS schema, + lower(c.table_name) AS name, + t.comment AS description, + decode(lower(t.table_type), 'view', 'true', 'false') AS is_view + FROM + {database}.{schema}.COLUMNS AS c + LEFT JOIN + {database}.{schema}.TABLES t + ON c.TABLE_NAME = t.TABLE_NAME + AND c.TABLE_SCHEMA = t.TABLE_SCHEMA + {where_clause_suffix}; + """ + + # CONFIG KEYS + WHERE_CLAUSE_SUFFIX_KEY = 'where_clause_suffix' + CLUSTER_KEY = 'cluster_key' + USE_CATALOG_AS_CLUSTER_NAME = 'use_catalog_as_cluster_name' + # Database Key, used to identify the database type in the UI. + DATABASE_KEY = 'database_key' + # Snowflake Database Key, used to determine which Snowflake database to connect to. + SNOWFLAKE_DATABASE_KEY = 'snowflake_database' + # Snowflake Schema Key, used to determine which Snowflake schema to use. 
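+ # Illustrative note (assumes the conventional job-level scoping, which is not defined in this + # file): these keys are typically supplied as 'extractor.snowflake.snowflake_database' and + # 'extractor.snowflake.snowflake_schema'; the defaults below are 'prod' and 'INFORMATION_SCHEMA'. 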
+ SNOWFLAKE_SCHEMA_KEY = 'snowflake_schema' + + # Default values + DEFAULT_CLUSTER_NAME = 'master' + + DEFAULT_CONFIG = ConfigFactory.from_dict( + {WHERE_CLAUSE_SUFFIX_KEY: ' ', + CLUSTER_KEY: DEFAULT_CLUSTER_NAME, + USE_CATALOG_AS_CLUSTER_NAME: True, + DATABASE_KEY: 'snowflake', + SNOWFLAKE_DATABASE_KEY: 'prod', + SNOWFLAKE_SCHEMA_KEY: 'INFORMATION_SCHEMA'} + ) + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(SnowflakeMetadataExtractor.DEFAULT_CONFIG) + self._cluster = conf.get_string(SnowflakeMetadataExtractor.CLUSTER_KEY) + + if conf.get_bool(SnowflakeMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME): + cluster_source = "c.table_catalog" + else: + cluster_source = f"'{self._cluster}'" + + self._database = conf.get_string(SnowflakeMetadataExtractor.DATABASE_KEY) + self._schema = conf.get_string(SnowflakeMetadataExtractor.DATABASE_KEY) + self._snowflake_database = conf.get_string(SnowflakeMetadataExtractor.SNOWFLAKE_DATABASE_KEY) + self._snowflake_schema = conf.get_string(SnowflakeMetadataExtractor.SNOWFLAKE_SCHEMA_KEY) + + self.sql_stmt = SnowflakeMetadataExtractor.SQL_STATEMENT.format( + where_clause_suffix=conf.get_string(SnowflakeMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY), + cluster_source=cluster_source, + database=self._snowflake_database, + schema=self._snowflake_schema + ) + + LOGGER.info('SQL for snowflake metadata: %s', self.sql_stmt) + + self._alchemy_extractor = sql_alchemy_extractor.from_surrounding_config(conf, self.sql_stmt) + self._extract_iter: Union[None, Iterator] = None + + def close(self) -> None: + if getattr(self, '_alchemy_extractor', None) is not None: + self._alchemy_extractor.close() + + def extract(self) -> Union[TableMetadata, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.snowflake' + + def _get_extract_iter(self) -> Iterator[TableMetadata]: + """ + Using itertools.groupby and raw level iterator, it groups to table and yields TableMetadata + :return: + """ + for key, group in groupby(self._get_raw_extract_iter(), self._get_table_key): + columns = [] + + for row in group: + last_row = row + columns.append(ColumnMetadata( + row['col_name'], + unidecode(row['col_description']) if row['col_description'] else None, + row['col_type'], + row['col_sort_order']) + ) + + yield TableMetadata(self._database, last_row['cluster'], + last_row['schema'], + last_row['name'], + unidecode(last_row['description']) if last_row['description'] else None, + columns, + last_row['is_view'] == 'true') + + def _get_raw_extract_iter(self) -> Iterator[Dict[str, Any]]: + """ + Provides iterator of result row from SQLAlchemy extractor + :return: + """ + row = self._alchemy_extractor.extract() + while row: + yield row + row = self._alchemy_extractor.extract() + + def _get_table_key(self, row: Dict[str, Any]) -> Union[TableKey, None]: + """ + Table key consists of schema and table name + :param row: + :return: + """ + if row: + return TableKey(schema=row['schema'], table_name=row['name']) + + return None diff --git a/databuilder/databuilder/extractor/snowflake_table_last_updated_extractor.py b/databuilder/databuilder/extractor/snowflake_table_last_updated_extractor.py new file mode 100644 index 0000000000..33f6b9657d --- /dev/null +++ b/databuilder/databuilder/extractor/snowflake_table_last_updated_extractor.py @@ -0,0 +1,107 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Iterator, Union + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor import sql_alchemy_extractor +from databuilder.extractor.base_extractor import Extractor +from databuilder.models.table_last_updated import TableLastUpdated + +LOGGER = logging.getLogger(__name__) + + +class SnowflakeTableLastUpdatedExtractor(Extractor): + """ + Extracts Snowflake table last update time from INFORMATION_SCHEMA metadata tables using SQLAlchemyExtractor. + Requirements: + snowflake-connector-python + snowflake-sqlalchemy + """ + # https://docs.snowflake.com/en/sql-reference/info-schema/views.html#columns + # 'last_altered' column in 'TABLES` metadata view under 'INFORMATION_SCHEMA' contains last time when the table was + # updated (both DML and DDL update). Below query fetches that column for each table. + SQL_STATEMENT = """ + SELECT + lower({cluster_source}) AS cluster, + lower(t.table_schema) AS schema, + lower(t.table_name) AS table_name, + DATE_PART(EPOCH, t.last_altered) AS last_updated_time + FROM + {database}.INFORMATION_SCHEMA.TABLES t + {where_clause_suffix}; + """ + + # CONFIG KEYS + WHERE_CLAUSE_SUFFIX_KEY = 'where_clause_suffix' + CLUSTER_KEY = 'cluster_key' + USE_CATALOG_AS_CLUSTER_NAME = 'use_catalog_as_cluster_name' + # Database Key, used to identify the database type in the UI. + DATABASE_KEY = 'database_key' + # Snowflake Database Key, used to determine which Snowflake database to connect to. + SNOWFLAKE_DATABASE_KEY = 'snowflake_database' + + # Default values + DEFAULT_CLUSTER_NAME = 'master' + + DEFAULT_CONFIG = ConfigFactory.from_dict( + {WHERE_CLAUSE_SUFFIX_KEY: ' WHERE t.last_altered IS NOT NULL ', + CLUSTER_KEY: DEFAULT_CLUSTER_NAME, + USE_CATALOG_AS_CLUSTER_NAME: True, + DATABASE_KEY: 'snowflake', + SNOWFLAKE_DATABASE_KEY: 'prod'} + ) + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(SnowflakeTableLastUpdatedExtractor.DEFAULT_CONFIG) + self._cluster = conf.get_string(SnowflakeTableLastUpdatedExtractor.CLUSTER_KEY) + + if conf.get_bool(SnowflakeTableLastUpdatedExtractor.USE_CATALOG_AS_CLUSTER_NAME): + cluster_source = "t.table_catalog" + else: + cluster_source = f"'{self._cluster}'" + + self._database = conf.get_string(SnowflakeTableLastUpdatedExtractor.DATABASE_KEY) + self._snowflake_database = conf.get_string(SnowflakeTableLastUpdatedExtractor.SNOWFLAKE_DATABASE_KEY) + + self.sql_stmt = SnowflakeTableLastUpdatedExtractor.SQL_STATEMENT.format( + where_clause_suffix=conf.get_string(SnowflakeTableLastUpdatedExtractor.WHERE_CLAUSE_SUFFIX_KEY), + cluster_source=cluster_source, + database=self._snowflake_database + ) + + LOGGER.info('SQL for snowflake table last updated timestamp: %s', self.sql_stmt) + + # use an sql_alchemy_extractor to execute sql + self._alchemy_extractor = sql_alchemy_extractor.from_surrounding_config(conf, self.sql_stmt) + self._extract_iter: Union[None, Iterator] = None + + def close(self) -> None: + if getattr(self, '_alchemy_extractor', None) is not None: + self._alchemy_extractor.close() + + def extract(self) -> Union[TableLastUpdated, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.snowflake_table_last_updated' + + def _get_extract_iter(self) -> Iterator[TableLastUpdated]: + """ + Provides iterator of result row from SQLAlchemy extractor + """ + tbl_last_updated_row = 
self._alchemy_extractor.extract() + while tbl_last_updated_row: + yield TableLastUpdated(table_name=tbl_last_updated_row['table_name'], + last_updated_time_epoch=tbl_last_updated_row['last_updated_time'], + schema=tbl_last_updated_row['schema'], + db=self._database, + cluster=tbl_last_updated_row['cluster']) + tbl_last_updated_row = self._alchemy_extractor.extract() diff --git a/databuilder/databuilder/extractor/sql_alchemy_extractor.py b/databuilder/databuilder/extractor/sql_alchemy_extractor.py new file mode 100644 index 0000000000..dfdec3935a --- /dev/null +++ b/databuilder/databuilder/extractor/sql_alchemy_extractor.py @@ -0,0 +1,107 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import importlib +from typing import Any + +from pyhocon import ConfigFactory, ConfigTree +from sqlalchemy import create_engine, text + +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor + + +class SQLAlchemyExtractor(Extractor): + # Config keys + CONN_STRING = 'conn_string' + EXTRACT_SQL = 'extract_sql' + CONNECT_ARGS = 'connect_args' + """ + An Extractor that extracts records via SQLAlchemy. Database that supports SQLAlchemy can use this extractor + """ + + def init(self, conf: ConfigTree) -> None: + """ + Establish connections and import data model class if provided + :param conf: + """ + self.conf = conf + self.conn_string = conf.get_string(SQLAlchemyExtractor.CONN_STRING) + + self.connection = self._get_connection() + + self.extract_sql = conf.get_string(SQLAlchemyExtractor.EXTRACT_SQL) + + model_class = conf.get('model_class', None) + if model_class: + module_name, class_name = model_class.rsplit(".", 1) + mod = importlib.import_module(module_name) + self.model_class = getattr(mod, class_name) + self._execute_query() + + def close(self) -> None: + if self.connection is not None: + self.connection.close() + + def _get_connection(self) -> Any: + """ + Create a SQLAlchemy connection to Database + """ + connect_args = { + k: v + for k, v in self.conf.get_config( + self.CONNECT_ARGS, default=ConfigTree() + ).items() + } + engine = create_engine(self.conn_string, connect_args=connect_args) + conn = engine.connect() + return conn + + def _execute_query(self) -> None: + """ + Create an iterator to execute sql. + """ + if not hasattr(self, 'results'): + results = self.connection.execute(text(self.extract_sql)) + # Makes this forward compatible with sqlalchemy >= 1.4 + if hasattr(results, "mappings"): + results = results.mappings() + self.results = results + + if hasattr(self, 'model_class'): + results = [self.model_class(**result) + for result in self.results] + else: + results = self.results + self.iter = iter(results) + + def extract(self) -> Any: + """ + Yield the sql result one at a time. + convert the result to model if a model_class is provided + """ + try: + return next(self.iter) + except StopIteration: + return None + except Exception as e: + raise e + + def get_scope(self) -> str: + return 'extractor.sqlalchemy' + + +def from_surrounding_config(conf: ConfigTree, sql_stmt: str) -> SQLAlchemyExtractor: + """ + A factory to create SQLAlchemyExtractors that are wrapped by another, specialized + extractor. This function pulls the config from the wrapping extractor's config, and + returns a newly configured SQLAlchemyExtractor. + :param conf: A config tree from which the sqlalchemy config still needs to be taken. + :param conf: The SQL statement to use for extraction. 
Expected to be set by the + wrapping extractor implementation, and not by the config. + """ + ae = SQLAlchemyExtractor() + c = Scoped.get_scoped_conf(conf, ae.get_scope()) \ + .with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: sql_stmt})) + ae.init(c) + return ae diff --git a/databuilder/databuilder/extractor/table_metadata_constants.py b/databuilder/databuilder/extractor/table_metadata_constants.py new file mode 100644 index 0000000000..b2a08766e2 --- /dev/null +++ b/databuilder/databuilder/extractor/table_metadata_constants.py @@ -0,0 +1,5 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +# String for partition column badge +PARTITION_BADGE = 'partition column' diff --git a/databuilder/databuilder/extractor/teradata_metadata_extractor.py b/databuilder/databuilder/extractor/teradata_metadata_extractor.py new file mode 100644 index 0000000000..2978a6263b --- /dev/null +++ b/databuilder/databuilder/extractor/teradata_metadata_extractor.py @@ -0,0 +1,46 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import ( # noqa: F401 + Any, Dict, Iterator, Union, +) + +from pyhocon import ConfigFactory, ConfigTree # noqa: F401 + +from databuilder.extractor.base_teradata_metadata_extractor import BaseTeradataMetadataExtractor + + +class TeradataMetadataExtractor(BaseTeradataMetadataExtractor): + """ + Extracts Teradata table and column metadata from underlying meta store database using SQLAlchemyExtractor + """ + + def get_sql_statement( + self, use_catalog_as_cluster_name: bool, where_clause_suffix: str + ) -> str: + if use_catalog_as_cluster_name: + cluster_source = "current_database()" + else: + cluster_source = f"'{self._cluster}'" + + return """ + SELECT + {cluster_source} as td_cluster, + c.DatabaseName as schema, + c.TableName as name, + c.CommentString as description, + d.ColumnName as col_name, + d.ColumnType as col_type, + d.CommentString as col_description, + d.ColumnId as col_sort_order + FROM dbc.Tables c, dbc.Columns d + WHERE c.DatabaseName = d.DatabaseName AND c.TableName = d.TableName + AND {where_clause_suffix} + ORDER by cluster_a, schema, name, col_sort_order; + """.format( + cluster_source=cluster_source, + where_clause_suffix=where_clause_suffix, + ) + + def get_scope(self) -> str: + return "extractor.teradata_metadata" diff --git a/databuilder/databuilder/extractor/user/__init__.py b/databuilder/databuilder/extractor/user/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/extractor/user/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/extractor/user/bamboohr/__init__.py b/databuilder/databuilder/extractor/user/bamboohr/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/extractor/user/bamboohr/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/extractor/user/bamboohr/bamboohr_user_extractor.py b/databuilder/databuilder/extractor/user/bamboohr/bamboohr_user_extractor.py new file mode 100644 index 0000000000..8477144b04 --- /dev/null +++ b/databuilder/databuilder/extractor/user/bamboohr/bamboohr_user_extractor.py @@ -0,0 +1,64 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + + +from typing import Iterator, Optional +from xml.etree import ElementTree + +import requests +from pyhocon import ConfigTree +from requests.auth import HTTPBasicAuth + +from databuilder.extractor.base_extractor import Extractor +from databuilder.models.user import User + + +class BamboohrUserExtractor(Extractor): + API_KEY = 'api_key' + SUBDOMAIN = 'subdomain' + + def init(self, conf: ConfigTree) -> None: + self._extract_iter: Optional[Iterator] = None + self._extract_iter = None + + self._api_key = conf.get_string(BamboohrUserExtractor.API_KEY) + self._subdomain = conf.get_string(BamboohrUserExtractor.SUBDOMAIN) + + def extract(self) -> Optional[User]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def _employee_directory_uri(self) -> str: + return f'https://api.bamboohr.com/api/gateway.php/{self._subdomain}/v1/employees/directory' + + def _get_extract_iter(self) -> Iterator[User]: + response = requests.get( + self._employee_directory_uri(), auth=HTTPBasicAuth(self._api_key, 'x') + ) + + root = ElementTree.fromstring(response.content) + + for user in root.findall('./employees/employee'): + + def get_field(name: str) -> str: + field = user.find(f"./field[@id='{name}']") + if field is not None and field.text is not None: + return field.text + else: + return '' + + yield User( + email=get_field('workEmail'), + first_name=get_field('firstName'), + last_name=get_field('lastName'), + name=get_field('displayName'), + team_name=get_field('department'), + role_name=get_field('jobTitle'), + ) + + def get_scope(self) -> str: + return 'extractor.bamboohr_user' diff --git a/databuilder/databuilder/extractor/vertica_metadata_extractor.py b/databuilder/databuilder/extractor/vertica_metadata_extractor.py new file mode 100644 index 0000000000..43ef7810c0 --- /dev/null +++ b/databuilder/databuilder/extractor/vertica_metadata_extractor.py @@ -0,0 +1,140 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections import namedtuple +from itertools import groupby +from typing import ( + Any, Dict, Iterator, Union, +) + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +TableKey = namedtuple('TableKey', ['schema', 'table_name']) + +LOGGER = logging.getLogger(__name__) + + +class VerticaMetadataExtractor(Extractor): + """ + Extracts vertica table and column metadata from underlying meta store database using SQLAlchemyExtractor + V_CATALOG does not have table and column description columns + CLUSTER_KEY config parameter is used as cluster name + Not distinguishing between table & view for now + """ + # SELECT statement from vertica information_schema to extract table and column metadata + SQL_STATEMENT = """ + SELECT + lower(c.column_name) AS col_name, + lower(c.data_type) AS col_type, + c.ordinal_position AS col_sort_order, + {cluster_source} AS cluster, + lower(c.table_schema) AS "schema", + lower(c.table_name) AS name, + False AS is_view + FROM + V_CATALOG.COLUMNS AS c + LEFT JOIN + V_CATALOG.TABLES t + ON c.TABLE_NAME = t.TABLE_NAME + AND c.TABLE_SCHEMA = t.TABLE_SCHEMA + {where_clause_suffix} + ORDER by cluster, "schema", name, col_sort_order ; + """ + + # CONFIG KEYS + WHERE_CLAUSE_SUFFIX_KEY = 'where_clause_suffix' + CLUSTER_KEY = 'cluster_key' + USE_CATALOG_AS_CLUSTER_NAME = 'use_catalog_as_cluster_name' + DATABASE_KEY = 'database_key' + + # Default values + DEFAULT_CLUSTER_NAME = 'master' + + DEFAULT_CONFIG = ConfigFactory.from_dict( + {WHERE_CLAUSE_SUFFIX_KEY: ' ', CLUSTER_KEY: DEFAULT_CLUSTER_NAME, USE_CATALOG_AS_CLUSTER_NAME: False} + ) + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(VerticaMetadataExtractor.DEFAULT_CONFIG) + self._cluster = '{}'.format(conf.get_string(VerticaMetadataExtractor.CLUSTER_KEY)) + + if conf.get_bool(VerticaMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME): + cluster_source = "c.table_catalog" + else: + cluster_source = "'{}'".format(self._cluster) + + self._database = conf.get_string(VerticaMetadataExtractor.DATABASE_KEY, default='vertica') + + self.sql_stmt = VerticaMetadataExtractor.SQL_STATEMENT.format( + where_clause_suffix=conf.get_string(VerticaMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY), + cluster_source=cluster_source + ) + + self._alchemy_extractor = SQLAlchemyExtractor() + sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope())\ + .with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt})) + + self.sql_stmt = sql_alch_conf.get_string(SQLAlchemyExtractor.EXTRACT_SQL) + + LOGGER.info('SQL for vertica metadata: {}'.format(self.sql_stmt)) + + self._alchemy_extractor.init(sql_alch_conf) + self._extract_iter: Union[None, Iterator] = None + + def extract(self) -> Union[TableMetadata, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.vertica_metadata' + + def _get_extract_iter(self) -> Iterator[TableMetadata]: + """ + Using itertools.groupby and raw level iterator, it groups to table and yields TableMetadata + :return: + """ + for key, group in groupby(self._get_raw_extract_iter(), self._get_table_key): + 
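# Note: groupby only merges adjacent rows, so this relies on the ORDER BY in SQL_STATEMENT + # above keeping every column row of a table contiguous in the result set. + 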
columns = []
+
+            for row in group:
+                last_row = row
+                columns.append(ColumnMetadata(row['col_name'], None,
+                                              row['col_type'], row['col_sort_order']))
+
+            yield TableMetadata(self._database, last_row['cluster'],
+                                last_row['schema'],
+                                last_row['name'],
+                                None,
+                                columns,
+                                is_view=last_row['is_view'])
+
+    def _get_raw_extract_iter(self) -> Iterator[Dict[str, Any]]:
+        """
+        Provides an iterator of result rows from the SQLAlchemy extractor
+        :return:
+        """
+        row = self._alchemy_extractor.extract()
+        while row:
+            yield row
+            row = self._alchemy_extractor.extract()
+
+    def _get_table_key(self, row: Dict[str, Any]) -> Union[TableKey, None]:
+        """
+        Table key consists of schema and table name
+        :param row:
+        :return:
+        """
+        if row:
+            return TableKey(schema=row['schema'], table_name=row['name'])
+
+        return None
diff --git a/databuilder/databuilder/filesystem/__init__.py b/databuilder/databuilder/filesystem/__init__.py
new file mode 100644
index 0000000000..f3145d75b3
--- /dev/null
+++ b/databuilder/databuilder/filesystem/__init__.py
@@ -0,0 +1,2 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
diff --git a/databuilder/databuilder/filesystem/filesystem.py b/databuilder/databuilder/filesystem/filesystem.py
new file mode 100644
index 0000000000..3c6e147df9
--- /dev/null
+++ b/databuilder/databuilder/filesystem/filesystem.py
@@ -0,0 +1,110 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from typing import List
+
+from pyhocon import ConfigFactory, ConfigTree
+from retrying import retry
+
+from databuilder import Scoped
+from databuilder.filesystem.metadata import FileMetadata
+
+LOGGER = logging.getLogger(__name__)
+CLIENT_ERRORS = {'ClientError', 'FileNotFoundError', 'ParamValidationError'}
+
+
+def is_client_side_error(e: Exception) -> bool:
+    """
+    A method that determines whether the error is a client-side error within the FileSystem context
+    :param e:
+    :return:
+    """
+    return e.__class__.__name__ in CLIENT_ERRORS
+
+
+def is_retriable_error(e: Exception) -> bool:
+    """
+    A method that determines whether the error is a retriable error within the FileSystem context
+    :param e:
+    :return:
+    """
+
+    return not is_client_side_error(e)
+
+
+class FileSystem(Scoped):
+    """
+    A high-level file system that utilizes the Dask file system.
+    http://docs.dask.org/en/latest/remote-data-services.html
+
+    All remote calls are retried on failure. https://pypi.org/project/retrying/
+    """
+
+    # METADATA KEYS
+    LAST_UPDATED = 'last_updated'
+    SIZE = 'size'
+
+    # CONFIG KEYS
+    DASK_FILE_SYSTEM = 'dask_file_system'
+
+    # File metadata is provided via the info(path) method of the Dask file system as a dictionary. As the dictionary
+    # does not guarantee the same keys across different implementations, the user can provide a key mapping.
+    FILE_METADATA_MAPPING_KEY = 'file_metadata_mapping'
+
+    default_metadata_mapping = {LAST_UPDATED: 'LastModified',
+                                SIZE: 'Size'}
+    DEFAULT_CONFIG = ConfigFactory.from_dict({FILE_METADATA_MAPPING_KEY: default_metadata_mapping})
+
+    def init(self,
+             conf: ConfigTree
+             ) -> None:
+        """
+        Initialize FileSystem with a Dask file system instance.
+        Dask file systems support multiple remote storages such as S3, HDFS, Google Cloud Storage,
+        Azure Data Lake, etc.
+
+        http://docs.dask.org/en/latest/remote-data-services.html
+        https://github.com/dask/s3fs
+        https://github.com/dask/hdfs3
+        ...
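+
+        Illustrative example (a sketch, not part of the original module; it
+        assumes the optional s3fs package is installed and credentials are
+        available in the environment):
+
+            import s3fs
+            from pyhocon import ConfigFactory
+
+            fs = FileSystem()
+            fs.init(ConfigFactory.from_dict(
+                {FileSystem.DASK_FILE_SYSTEM: s3fs.S3FileSystem()}))
+            names = fs.ls('my-bucket/some/prefix')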
+
+        :param conf: hocon config
+        :return:
+        """
+        self._conf = conf.with_fallback(FileSystem.DEFAULT_CONFIG)
+        self._dask_fs = self._conf.get(FileSystem.DASK_FILE_SYSTEM)
+        self._metadata_key_mapping = self._conf.get(FileSystem.FILE_METADATA_MAPPING_KEY).as_plain_ordered_dict()
+
+    @retry(retry_on_exception=is_retriable_error, stop_max_attempt_number=3, wait_exponential_multiplier=1000,
+           wait_exponential_max=5000)
+    def ls(self, path: str) -> List[str]:
+        """
+        Lists the files and directories under the given path using the underlying Dask file system.
+        :param path:
+        :return:
+        """
+        return self._dask_fs.ls(path)
+
+    @retry(retry_on_exception=is_retriable_error, stop_max_attempt_number=3, wait_exponential_multiplier=1000,
+           wait_exponential_max=5000)
+    def is_file(self, path: str) -> bool:
+        contents = self._dask_fs.ls(path)
+        return len(contents) == 1 and contents[0] == path
+
+    @retry(retry_on_exception=is_retriable_error, stop_max_attempt_number=3, wait_exponential_multiplier=1000,
+           wait_exponential_max=5000)
+    def info(self, path: str) -> FileMetadata:
+        """
+        Metadata information about the file. It utilizes _metadata_key_mapping when fetching metadata so that it can
+        deal with different keys.
+        :return:
+        """
+        metadata_dict = self._dask_fs.info(path)
+        fm = FileMetadata(path=path,
+                          last_updated=metadata_dict[self._metadata_key_mapping[FileSystem.LAST_UPDATED]],
+                          size=metadata_dict[self._metadata_key_mapping[FileSystem.SIZE]])
+        return fm
+
+    def get_scope(self) -> str:
+        return 'filesystem'
diff --git a/databuilder/databuilder/filesystem/metadata.py b/databuilder/databuilder/filesystem/metadata.py
new file mode 100644
index 0000000000..3ebed1e210
--- /dev/null
+++ b/databuilder/databuilder/filesystem/metadata.py
@@ -0,0 +1,19 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+from datetime import datetime
+
+
+class FileMetadata(object):
+
+    def __init__(self,
+                 path: str,
+                 last_updated: datetime,
+                 size: int
+                 ) -> None:
+        self.path = path
+        self.last_updated = last_updated
+        self.size = size
+
+    def __repr__(self) -> str:
+        return f'FileMetadata(path={self.path!r}, last_updated={self.last_updated!r}, size={self.size!r})'
diff --git a/databuilder/databuilder/job/__init__.py b/databuilder/databuilder/job/__init__.py
new file mode 100644
index 0000000000..f3145d75b3
--- /dev/null
+++ b/databuilder/databuilder/job/__init__.py
@@ -0,0 +1,2 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
diff --git a/databuilder/databuilder/job/base_job.py b/databuilder/databuilder/job/base_job.py
new file mode 100644
index 0000000000..0fca300a07
--- /dev/null
+++ b/databuilder/databuilder/job/base_job.py
@@ -0,0 +1,31 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+import abc
+
+from pyhocon import ConfigTree
+
+from databuilder import Scoped
+from databuilder.utils.closer import Closer
+
+
+class Job(Scoped):
+    closer = Closer()
+
+    """
+    A Databuilder job that represents a single unit of work.
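+
+    A typical concrete job wires a task (extractor + loader) together with an
+    optional publisher and is launched once. For example (an illustrative
+    sketch; my_extractor, my_loader and the config values are placeholders):
+
+        from pyhocon import ConfigFactory
+        from databuilder.job.job import DefaultJob
+        from databuilder.task.task import DefaultTask
+
+        job = DefaultJob(conf=ConfigFactory.from_dict({'job.identifier': 'example_job'}),
+                         task=DefaultTask(extractor=my_extractor, loader=my_loader))
+        job.launch()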
+ """ + @abc.abstractmethod + def init(self, conf: ConfigTree) -> None: + pass + + @abc.abstractmethod + def launch(self) -> None: + """ + Launch a job + :return: None + """ + pass + + def get_scope(self) -> str: + return 'job' diff --git a/databuilder/databuilder/job/job.py b/databuilder/databuilder/job/job.py new file mode 100644 index 0000000000..b3791a395e --- /dev/null +++ b/databuilder/databuilder/job/job.py @@ -0,0 +1,89 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging + +from pyhocon import ConfigTree +from statsd import StatsClient + +from databuilder import Scoped +from databuilder.job.base_job import Job +from databuilder.publisher.base_publisher import NoopPublisher, Publisher +from databuilder.task.base_task import Task + +LOGGER = logging.getLogger(__name__) + + +class DefaultJob(Job): + # Config keys + IS_STATSD_ENABLED = 'is_statsd_enabled' + JOB_IDENTIFIER = 'identifier' + + """ + Default job that expects a task, and optional publisher + If configured job will emit success/fail metric counter through statsd where prefix will be + amundsen.databuilder.job.[identifier] . + Note that job.identifier is part of metrics prefix and choose unique & readable identifier for the job. + + To configure statsd itself, use environment variable: https://statsd.readthedocs.io/en/v3.2.1/configure.html + """ + + def __init__(self, + conf: ConfigTree, + task: Task, + publisher: Publisher = NoopPublisher()) -> None: + self.task = task + self.conf = conf + self.publisher = publisher + self.scoped_conf = Scoped.get_scoped_conf(self.conf, + self.get_scope()) + if self.scoped_conf.get_bool(DefaultJob.IS_STATSD_ENABLED, False): + prefix = f'amundsen.databuilder.job.{self.scoped_conf.get_string(DefaultJob.JOB_IDENTIFIER)}' + LOGGER.info('Setting statsd for job metrics with prefix: %s', prefix) + self.statsd = StatsClient(prefix=prefix) + else: + self.statsd = None + + def init(self, conf: ConfigTree) -> None: + pass + + def _init(self) -> None: + self.task.init(self.conf) + + def launch(self) -> None: + """ + Launch a job by initializing job, run task and publish. + :return: + """ + + logging.info('Launching a job') + # Using nested try finally to make sure task get closed as soon as possible as well as to guarantee all the + # closeable get closed. + try: + is_success = True + self._init() + try: + self.task.run() + finally: + self.task.close() + + self.publisher.init(Scoped.get_scoped_conf(self.conf, self.publisher.get_scope())) + Job.closer.register(self.publisher.close) + self.publisher.publish() + + except Exception as e: + is_success = False + raise e + finally: + # TODO: If more metrics are needed on different construct, such as task, consider abstracting this out + if self.statsd: + if is_success: + LOGGER.info('Publishing job metrics for success') + self.statsd.incr('success') + else: + LOGGER.info('Publishing job metrics for failure') + self.statsd.incr('fail') + + Job.closer.close() + + logging.info('Job completed') diff --git a/databuilder/databuilder/loader/__init__.py b/databuilder/databuilder/loader/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/loader/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/loader/base_loader.py b/databuilder/databuilder/loader/base_loader.py new file mode 100644 index 0000000000..38562e7006 --- /dev/null +++ b/databuilder/databuilder/loader/base_loader.py @@ -0,0 +1,25 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import abc +from typing import Any + +from pyhocon import ConfigTree + +from databuilder import Scoped + + +class Loader(Scoped): + """ + A loader loads to the destination or to the staging area + """ + @abc.abstractmethod + def init(self, conf: ConfigTree) -> None: + pass + + @abc.abstractmethod + def load(self, record: Any) -> None: + pass + + def get_scope(self) -> str: + return 'loader' diff --git a/databuilder/databuilder/loader/file_system_atlas_csv_loader.py b/databuilder/databuilder/loader/file_system_atlas_csv_loader.py new file mode 100644 index 0000000000..88c6897611 --- /dev/null +++ b/databuilder/databuilder/loader/file_system_atlas_csv_loader.py @@ -0,0 +1,201 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import csv +import logging +import os +import shutil +from csv import DictWriter +from typing import ( + Any, Dict, FrozenSet, +) + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.job.base_job import Job +from databuilder.loader.base_loader import Loader +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.serializers import atlas_serializer +from databuilder.utils.closer import Closer + +LOGGER = logging.getLogger(__name__) + + +class FsAtlasCSVLoader(Loader): + """ + Write entity and relationship CSV file(s) that can be consumed by + AtlasCsvPublisher. + It assumes that the record it consumes is instance of AtlasCsvSerializable + """ + # Config keys + ENTITY_DIR_PATH = 'entity_dir_path' + RELATIONSHIP_DIR_PATH = 'relationship_dir_path' + FORCE_CREATE_DIR = 'force_create_directory' + SHOULD_DELETE_CREATED_DIR = 'delete_created_directories' + + _DEFAULT_CONFIG = ConfigFactory.from_dict({ + SHOULD_DELETE_CREATED_DIR: True, + FORCE_CREATE_DIR: False, + }) + + def __init__(self) -> None: + self._entity_file_mapping: Dict[Any, DictWriter] = {} + self._relation_file_mapping: Dict[Any, DictWriter] = {} + self._keys: Dict[FrozenSet[str], int] = {} + self._closer = Closer() + + def init(self, conf: ConfigTree) -> None: + """ + Initializing FsAtlasCSVLoader by creating directory for entity files + and relationship files. Note that the directory defined in + configuration should not exist. + :param conf: + :return: + """ + conf = conf.with_fallback(FsAtlasCSVLoader._DEFAULT_CONFIG) + + self._entity_dir = conf.get_string(FsAtlasCSVLoader.ENTITY_DIR_PATH) + self._relation_dir = \ + conf.get_string(FsAtlasCSVLoader.RELATIONSHIP_DIR_PATH) + + self._delete_created_dir = \ + conf.get_bool(FsAtlasCSVLoader.SHOULD_DELETE_CREATED_DIR) + self._force_create_dir = conf.get_bool(FsAtlasCSVLoader.FORCE_CREATE_DIR) + self._create_directory(self._entity_dir) + self._create_directory(self._relation_dir) + + def _create_directory(self, path: str) -> None: + """ + Validate directory does not exist, creates it, register deletion of + created directory function to Job.closer. + :param path: + :return: + """ + if os.path.exists(path): + if self._force_create_dir: + LOGGER.info('Directory exist. 
Deleting directory %s', path) + shutil.rmtree(path) + else: + raise RuntimeError(f'Directory should not exist: {path}') + + os.makedirs(path) + + def _delete_dir() -> None: + if not self._delete_created_dir: + LOGGER.warning('Skip Deleting directory %s', path) + return + + LOGGER.info('Deleting directory %s', path) + shutil.rmtree(path) + + # Directory should be deleted after publish is finished + Job.closer.register(_delete_dir) + + def load(self, csv_serializable: AtlasSerializable) -> None: + """ + Writes AtlasSerializable into CSV files. + There are multiple CSV files that this method writes. + This is because there're not only node and relationship, but also it + can also have different entities, and relationships. + + Common pattern for both entities and relations: + 1. retrieve csv row (a dict where keys represent a header, + values represent a row) + 2. using this dict to get a appropriate csv writer and write to it. + 3. repeat 1 and 2 + + :param csv_serializable: + :return: + """ + + entity = csv_serializable.next_atlas_entity() + while entity: + entity_dict = atlas_serializer.serialize_entity(entity) + key = (self._make_key(entity_dict), entity.typeName) + file_suffix = '{}_{}'.format(*key) + entity_writer = self._get_writer( + entity_dict, + self._entity_file_mapping, + key, + self._entity_dir, + file_suffix, + ) + entity_writer.writerow(entity_dict) + entity = csv_serializable.next_atlas_entity() + + relation = csv_serializable.next_atlas_relation() + while relation: + relation_dict = atlas_serializer.serialize_relationship(relation) + keys = ( + self._make_key(relation_dict), + relation.entityType1, + relation.entityType2, + ) + + file_suffix = '{}_{}_{}'.format(*keys) + relation_writer = self._get_writer( + relation_dict, + self._relation_file_mapping, + keys, + self._relation_dir, + file_suffix, + ) + relation_writer.writerow(relation_dict) + relation = csv_serializable.next_atlas_relation() + + def _get_writer( + self, + csv_record_dict: Dict[str, Any], + file_mapping: Dict[Any, DictWriter], + key: Any, + dir_path: str, + file_suffix: str, + ) -> DictWriter: + """ + Finds a writer based on csv record, key. + If writer does not exist, it's creates a csv writer and update the + mapping. + + :param csv_record_dict: + :param file_mapping: + :param key: + :param file_suffix: + :return: + """ + writer = file_mapping.get(key) + if writer: + return writer + + LOGGER.info('Creating file for %s', key) + + file_out = open(f'{dir_path}/{file_suffix}.csv', 'w', encoding='utf8') + writer = csv.DictWriter( # type: ignore + file_out, + fieldnames=csv_record_dict.keys(), + quoting=csv.QUOTE_NONNUMERIC, + ) + + def file_out_close() -> None: + LOGGER.info('Closing file IO %s', file_out) + file_out.close() + + self._closer.register(file_out_close) + + writer.writeheader() + file_mapping[key] = writer + + return writer + + def close(self) -> None: + """ + Any closeable callable registered in _closer, it will close. 
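+
+        For reference, a minimal configuration for this loader might look like
+        the following (an illustrative sketch; the directory paths are
+        placeholders and the key prefix matches get_scope() below):
+
+            conf = ConfigFactory.from_dict({
+                'loader.filesystem_csv_atlas.entity_dir_path': '/tmp/amundsen/entities',
+                'loader.filesystem_csv_atlas.relationship_dir_path': '/tmp/amundsen/relationships',
+            })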
+ :return: + """ + self._closer.close() + + def get_scope(self) -> str: + return "loader.filesystem_csv_atlas" + + def _make_key(self, record_dict: Dict[str, Any]) -> str: + """ Each unique set of record keys is assigned an increasing numeric key """ + return str(self._keys.setdefault(frozenset(record_dict.keys()), len(self._keys))).rjust(3, '0') diff --git a/databuilder/databuilder/loader/file_system_csv_loader.py b/databuilder/databuilder/loader/file_system_csv_loader.py new file mode 100644 index 0000000000..ac6d8794e7 --- /dev/null +++ b/databuilder/databuilder/loader/file_system_csv_loader.py @@ -0,0 +1,60 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import csv +import logging +from typing import Any + +from pyhocon import ConfigTree + +from databuilder.loader.base_loader import Loader + +LOGGER = logging.getLogger(__name__) + + +class FileSystemCSVLoader(Loader): + """ + Loader class to write csv files to Local FileSystem + """ + + def init(self, conf: ConfigTree) -> None: + """ + Initialize file handlers from conf + :param conf: + """ + self.conf = conf + self.file_path = self.conf.get_string('file_path') + self.file_mode = self.conf.get_string('mode', 'w') + + self.file_handler = open(self.file_path, self.file_mode) + + def load(self, record: Any) -> None: + """ + Write record object as csv to file + :param record: + :return: + """ + if not record: + return + + if not hasattr(self, 'writer'): + self.writer = csv.DictWriter(self.file_handler, + fieldnames=vars(record).keys()) + self.writer.writeheader() + + self.writer.writerow(vars(record)) + self.file_handler.flush() + + def close(self) -> None: + """ + Close file handlers + :return: + """ + try: + if self.file_handler: + self.file_handler.close() + except Exception as e: + LOGGER.warning("Failed trying to close a file handler! %s", e) + + def get_scope(self) -> str: + return "loader.filesystem.csv" diff --git a/databuilder/databuilder/loader/file_system_elasticsearch_json_loader.py b/databuilder/databuilder/loader/file_system_elasticsearch_json_loader.py new file mode 100644 index 0000000000..142e78b925 --- /dev/null +++ b/databuilder/databuilder/loader/file_system_elasticsearch_json_loader.py @@ -0,0 +1,68 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import os + +from pyhocon import ConfigTree + +from databuilder.loader.base_loader import Loader +from databuilder.models.elasticsearch_document import ElasticsearchDocument + + +class FSElasticsearchJSONLoader(Loader): + """ + Loader class to produce Elasticsearch bulk load file to Local FileSystem + """ + FILE_PATH_CONFIG_KEY = 'file_path' + FILE_MODE_CONFIG_KEY = 'mode' + + def init(self, conf: ConfigTree) -> None: + """ + + :param conf: + :return: + """ + self.conf = conf + self.file_path = self.conf.get_string(FSElasticsearchJSONLoader.FILE_PATH_CONFIG_KEY) + self.file_mode = self.conf.get_string(FSElasticsearchJSONLoader.FILE_MODE_CONFIG_KEY, 'w') + + file_dir = self.file_path.rsplit('/', 1)[0] + self._ensure_directory_exists(file_dir) + self.file_handler = open(self.file_path, self.file_mode) + + def _ensure_directory_exists(self, path: str) -> None: + """ + Check to ensure file directory exists; create the directories otherwise + :param path: + :return: None + """ + if os.path.exists(path): + return # nothing to do here + + os.makedirs(path) + + def load(self, record: ElasticsearchDocument) -> None: + """ + Write a record in json format to file + :param record: + :return: + """ + if not record: + return + + if not isinstance(record, ElasticsearchDocument): + raise Exception("Record not of type 'ElasticsearchDocument'!") + + self.file_handler.write(record.to_json()) + self.file_handler.flush() + + def close(self) -> None: + """ + close the file handler + :return: + """ + if self.file_handler: + self.file_handler.close() + + def get_scope(self) -> str: + return 'loader.filesystem.elasticsearch' diff --git a/databuilder/databuilder/loader/file_system_mysql_csv_loader.py b/databuilder/databuilder/loader/file_system_mysql_csv_loader.py new file mode 100644 index 0000000000..1214225b91 --- /dev/null +++ b/databuilder/databuilder/loader/file_system_mysql_csv_loader.py @@ -0,0 +1,163 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import csv +import logging +import os +import shutil +from csv import DictWriter +from typing import ( + Any, Dict, FrozenSet, +) + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.job.base_job import Job +from databuilder.loader.base_loader import Loader +from databuilder.models.table_serializable import TableSerializable +from databuilder.serializers import mysql_serializer +from databuilder.utils.closer import Closer + +LOGGER = logging.getLogger(__name__) + + +class FSMySQLCSVLoader(Loader): + """ + Write table record CSV file(s) that can be consumed by MySQLCsvPublisher. + It assumes that the record it consumes is instance of TableSerializable. + """ + # Config keys + RECORD_DIR_PATH = 'record_dir_path' + FORCE_CREATE_DIR = 'force_create_directory' + SHOULD_DELETE_CREATED_DIR = 'delete_created_directories' + + _DEFAULT_CONFIG = ConfigFactory.from_dict({ + SHOULD_DELETE_CREATED_DIR: True, + FORCE_CREATE_DIR: False + }) + + def __init__(self) -> None: + self._record_file_mapping: Dict[Any, DictWriter] = {} + self._keys: Dict[FrozenSet[str], int] = {} + self._closer = Closer() + + def init(self, conf: ConfigTree) -> None: + """ + Initializing FsMySQLCSVLoader by creating directory for record files. + Note that the directory defined in configuration should not exist. 
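+
+        For example, a minimal configuration might look like the following
+        (an illustrative sketch; the path is a placeholder and the key prefix
+        matches get_scope() below):
+
+            conf = ConfigFactory.from_dict({
+                'loader.mysql_filesystem_csv.record_dir_path': '/tmp/amundsen/records',
+            })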
+ :param conf: + :return: + """ + conf = conf.with_fallback(FSMySQLCSVLoader._DEFAULT_CONFIG) + + self._record_dir = conf.get_string(FSMySQLCSVLoader.RECORD_DIR_PATH) + self._delete_created_dir = conf.get_bool(FSMySQLCSVLoader.SHOULD_DELETE_CREATED_DIR) + self._force_create_dir = conf.get_bool(FSMySQLCSVLoader.FORCE_CREATE_DIR) + self._create_directory(self._record_dir) + + def _create_directory(self, path: str) -> None: + """ + Validate directory does not exist, creates it, register deletion of + created directory function to Job.closer. + :param path: + :return: + """ + if os.path.exists(path): + if self._force_create_dir: + LOGGER.info(f'Directory exist. Deleting directory {path}') + shutil.rmtree(path) + else: + raise RuntimeError(f'Directory should not exist: {path}') + + os.makedirs(path) + + def _delete_dir() -> None: + if not self._delete_created_dir: + LOGGER.warning(f'Skip Deleting directory {path}') + return + + LOGGER.info(f'Deleting directory {path}') + shutil.rmtree(path) + + # Directory should be deleted after publish is finished + Job.closer.register(_delete_dir) + + def load(self, csv_serializable: TableSerializable) -> None: + """ + Writes TableSerializable records into CSV files. + There are multiple CSV files meaning different tables that this method writes. + + Common pattern for table records: + 1. retrieve csv row (a dict where keys represent a header, + values represent a row) + 2. using this dict to get a appropriate csv writer and write to it. + 3. repeat 1 and 2 + + :param csv_serializable: + :return: + """ + record = csv_serializable.next_record() + while record: + record_dict = mysql_serializer.serialize_record(record) + table_name = record.__tablename__ + key = (table_name, self._make_key(record_dict)) + file_suffix = '{}_{}'.format(*key) + record_writer = self._get_writer(record_dict, + self._record_file_mapping, + key, + self._record_dir, + file_suffix) + record_writer.writerow(record_dict) + record = csv_serializable.next_record() + + def _get_writer(self, + csv_record_dict: Dict[str, Any], + file_mapping: Dict[Any, DictWriter], + key: Any, + dir_path: str, + file_suffix: str + ) -> DictWriter: + """ + Finds a writer based on csv record, key. + If writer does not exist, it's creates a csv writer and update the mapping. + + :param csv_record_dict: + :param file_mapping: + :param key: + :param dir_path: + :param file_suffix: + :return: + """ + writer = file_mapping.get(key) + if writer: + return writer + + LOGGER.info(f'Creating file for {key}') + + file_out = open(f'{dir_path}/{file_suffix}.csv', 'w', encoding='utf8') + writer = csv.DictWriter(file_out, fieldnames=csv_record_dict.keys(), + quoting=csv.QUOTE_NONNUMERIC) + + def file_out_close() -> None: + LOGGER.info(f'Closing file IO {file_out}') + file_out.close() + self._closer.register(file_out_close) + + writer.writeheader() + file_mapping[key] = writer + + return writer + + def close(self) -> None: + """ + Any closeable callable registered in _closer, it will close. 
+ :return: + """ + self._closer.close() + + def get_scope(self) -> str: + return "loader.mysql_filesystem_csv" + + def _make_key(self, record_dict: Dict[str, Any]) -> int: + """ Each unique set of record keys is assigned an increasing numeric key """ + return self._keys.setdefault(frozenset(record_dict.keys()), len(self._keys)) diff --git a/databuilder/databuilder/loader/file_system_neo4j_csv_loader.py b/databuilder/databuilder/loader/file_system_neo4j_csv_loader.py new file mode 100644 index 0000000000..31106d34bf --- /dev/null +++ b/databuilder/databuilder/loader/file_system_neo4j_csv_loader.py @@ -0,0 +1,191 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import csv +import logging +import os +import shutil +from csv import DictWriter +from typing import ( + Any, Dict, FrozenSet, +) + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.job.base_job import Job +from databuilder.loader.base_loader import Loader +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.serializers import neo4_serializer +from databuilder.utils.closer import Closer + +LOGGER = logging.getLogger(__name__) + + +class FsNeo4jCSVLoader(Loader): + """ + Write node and relationship CSV file(s) that can be consumed by + Neo4jCsvPublisher. + It assumes that the record it consumes is instance of Neo4jCsvSerializable + """ + # Config keys + NODE_DIR_PATH = 'node_dir_path' + RELATION_DIR_PATH = 'relationship_dir_path' + FORCE_CREATE_DIR = 'force_create_directory' + SHOULD_DELETE_CREATED_DIR = 'delete_created_directories' + + _DEFAULT_CONFIG = ConfigFactory.from_dict({ + SHOULD_DELETE_CREATED_DIR: True, + FORCE_CREATE_DIR: False + }) + + def __init__(self) -> None: + self._node_file_mapping: Dict[Any, DictWriter] = {} + self._relation_file_mapping: Dict[Any, DictWriter] = {} + self._keys: Dict[FrozenSet[str], int] = {} + self._closer = Closer() + + def init(self, conf: ConfigTree) -> None: + """ + Initializing FsNeo4jCsvLoader by creating directory for node files + and relationship files. Note that the directory defined in + configuration should not exist. + :param conf: + :return: + """ + conf = conf.with_fallback(FsNeo4jCSVLoader._DEFAULT_CONFIG) + + self._node_dir = conf.get_string(FsNeo4jCSVLoader.NODE_DIR_PATH) + self._relation_dir = \ + conf.get_string(FsNeo4jCSVLoader.RELATION_DIR_PATH) + + self._delete_created_dir = \ + conf.get_bool(FsNeo4jCSVLoader.SHOULD_DELETE_CREATED_DIR) + self._force_create_dir = conf.get_bool(FsNeo4jCSVLoader.FORCE_CREATE_DIR) + self._create_directory(self._node_dir) + self._create_directory(self._relation_dir) + + def _create_directory(self, path: str) -> None: + """ + Validate directory does not exist, creates it, register deletion of + created directory function to Job.closer. + :param path: + :return: + """ + if os.path.exists(path): + if self._force_create_dir: + LOGGER.info('Directory exist. Deleting directory %s', path) + shutil.rmtree(path) + else: + raise RuntimeError(f'Directory should not exist: {path}') + + os.makedirs(path) + + def _delete_dir() -> None: + if not self._delete_created_dir: + LOGGER.warning('Skip Deleting directory %s', path) + return + + LOGGER.info('Deleting directory %s', path) + shutil.rmtree(path) + + # Directory should be deleted after publish is finished + Job.closer.register(_delete_dir) + + def load(self, csv_serializable: GraphSerializable) -> None: + """ + Writes Neo4jCsvSerializable into CSV files. + There are multiple CSV files that this method writes. 
+ This is because there're not only node and relationship, but also it + can also have different nodes, and relationships. + + Common pattern for both nodes and relations: + 1. retrieve csv row (a dict where keys represent a header, + values represent a row) + 2. using this dict to get a appropriate csv writer and write to it. + 3. repeat 1 and 2 + + :param csv_serializable: + :return: + """ + + node = csv_serializable.next_node() + while node: + node_dict = neo4_serializer.serialize_node(node) + key = (node.label, self._make_key(node_dict)) + file_suffix = '{}_{}'.format(*key) + node_writer = self._get_writer(node_dict, + self._node_file_mapping, + key, + self._node_dir, + file_suffix) + node_writer.writerow(node_dict) + node = csv_serializable.next_node() + + relation = csv_serializable.next_relation() + while relation: + relation_dict = neo4_serializer.serialize_relationship(relation) + key2 = (relation.start_label, + relation.end_label, + relation.type, + self._make_key(relation_dict)) + + file_suffix = f'{key2[0]}_{key2[1]}_{key2[2]}_{key2[3]}' + relation_writer = self._get_writer(relation_dict, + self._relation_file_mapping, + key2, + self._relation_dir, + file_suffix) + relation_writer.writerow(relation_dict) + relation = csv_serializable.next_relation() + + def _get_writer(self, + csv_record_dict: Dict[str, Any], + file_mapping: Dict[Any, DictWriter], + key: Any, + dir_path: str, + file_suffix: str + ) -> DictWriter: + """ + Finds a writer based on csv record, key. + If writer does not exist, it's creates a csv writer and update the + mapping. + + :param csv_record_dict: + :param file_mapping: + :param key: + :param file_suffix: + :return: + """ + writer = file_mapping.get(key) + if writer: + return writer + + LOGGER.info('Creating file for %s', key) + + file_out = open(f'{dir_path}/{file_suffix}.csv', 'w', encoding='utf8') + writer = csv.DictWriter(file_out, fieldnames=csv_record_dict.keys(), + quoting=csv.QUOTE_NONNUMERIC) + + def file_out_close() -> None: + LOGGER.info('Closing file IO %s', file_out) + file_out.close() + self._closer.register(file_out_close) + + writer.writeheader() + file_mapping[key] = writer + + return writer + + def close(self) -> None: + """ + Any closeable callable registered in _closer, it will close. + :return: + """ + self._closer.close() + + def get_scope(self) -> str: + return "loader.filesystem_csv_neo4j" + + def _make_key(self, record_dict: Dict[str, Any]) -> int: + """ Each unique set of record keys is assigned an increasing numeric key """ + return self._keys.setdefault(frozenset(record_dict.keys()), len(self._keys)) diff --git a/databuilder/databuilder/loader/file_system_neptune_csv_loader.py b/databuilder/databuilder/loader/file_system_neptune_csv_loader.py new file mode 100644 index 0000000000..f81e16b1f6 --- /dev/null +++ b/databuilder/databuilder/loader/file_system_neptune_csv_loader.py @@ -0,0 +1,179 @@ +# Copyright Contributors to the Amundsen project. 
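+# Illustrative configuration for the loader defined in this file (a sketch,
+# not part of the original module; the paths and publisher tag are
+# placeholders):
+#
+#     conf = ConfigFactory.from_dict({
+#         'loader.neptune_filesystem_csv.node_dir_path': '/tmp/amundsen/nodes',
+#         'loader.neptune_filesystem_csv.relationship_dir_path': '/tmp/amundsen/relationships',
+#         'loader.neptune_filesystem_csv.job_publisher_tag': 'unique_tag',
+#     })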
+# SPDX-License-Identifier: Apache-2.0 + +import csv +import logging +import os +import shutil +from csv import DictWriter +from typing import Any, Dict + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.job.base_job import Job +from databuilder.loader.base_loader import Loader +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.serializers import neptune_serializer +from databuilder.utils.closer import Closer + +LOGGER = logging.getLogger(__name__) + +PUBLISHED_TAG_PROPERTY_NAME = 'published_tag' + + +class FSNeptuneCSVLoader(Loader): + """ + Write node and relationship CSV file(s) that can be consumed by + NeptuneCsvPublisher. + It assumes that the record it consumes is instance of GraphSerializable + """ + # Config keys + NODE_DIR_PATH = 'node_dir_path' + RELATION_DIR_PATH = 'relationship_dir_path' + FORCE_CREATE_DIR = 'force_create_directory' + SHOULD_DELETE_CREATED_DIR = 'delete_created_directories' + JOB_PUBLISHER_TAG = 'job_publisher_tag' + + _DEFAULT_CONFIG = ConfigFactory.from_dict({ + SHOULD_DELETE_CREATED_DIR: True, + FORCE_CREATE_DIR: False + }) + + def __init__(self) -> None: + self._node_file_mapping: Dict[Any, DictWriter] = {} + self._relation_file_mapping: Dict[Any, DictWriter] = {} + self._closer = Closer() + + def init(self, conf: ConfigTree) -> None: + """ + Initializing FSNeptuneCSVLoader by creating directory for node files + and relationship files. Note that the directory defined in + configuration should not exist. + """ + conf = conf.with_fallback(FSNeptuneCSVLoader._DEFAULT_CONFIG) + + self._node_dir = conf.get_string(FSNeptuneCSVLoader.NODE_DIR_PATH) + self._relation_dir = conf.get_string(FSNeptuneCSVLoader.RELATION_DIR_PATH) + + self._delete_created_dir = conf.get_bool(FSNeptuneCSVLoader.SHOULD_DELETE_CREATED_DIR) + self._force_create_dir = conf.get_bool(FSNeptuneCSVLoader.FORCE_CREATE_DIR) + self._create_directory(self._node_dir) + self._create_directory(self._relation_dir) + self.job_publisher_tag = conf.get_string(FSNeptuneCSVLoader.JOB_PUBLISHER_TAG) + + def _create_directory(self, path: str) -> None: + """ + Validate directory does not exist, creates it, register deletion of + created directory function to Job.closer. + """ + if os.path.exists(path): + if self._force_create_dir: + LOGGER.info('Directory exist. Deleting directory {}'.format(path)) + shutil.rmtree(path) + else: + raise RuntimeError('Directory should not exist: {}'.format(path)) + + os.makedirs(path) + + def _delete_dir() -> None: + if not self._delete_created_dir: + LOGGER.warn('Skip Deleting directory {}'.format(path)) + return + + LOGGER.info('Deleting directory {}'.format(path)) + shutil.rmtree(path) + + # Directory should be deleted after publish is finished + Job.closer.register(_delete_dir) + + def load(self, csv_serializable: GraphSerializable) -> None: + """ + Writes GraphSerializable into CSV files. + There are multiple CSV files that this method writes. + This is because there're not only node and relationship, but also it + can also have different nodes, and relationships. + Common pattern for both nodes and relations: + 1. retrieve csv row (a dict where keys represent a header, + values represent a row) + 2. using this dict to get a appropriate csv writer and write to it. + 3. 
repeat 1 and 2 + :param csv_serializable: + :return: + """ + + node = csv_serializable.next_node() + while node: + + node.attributes[PUBLISHED_TAG_PROPERTY_NAME] = self.job_publisher_tag + node_dict = neptune_serializer.convert_node(node) + if node_dict: + key = (node.label, len(node_dict)) + file_suffix = '{}_{}'.format(*key) + node_writer = self._get_writer( + node_dict, + self._node_file_mapping, + key, + self._node_dir, + file_suffix + ) + node_writer.writerow(node_dict) + node = csv_serializable.next_node() + + relation = csv_serializable.next_relation() + while relation: + relation.attributes[PUBLISHED_TAG_PROPERTY_NAME] = self.job_publisher_tag + relation_dicts = neptune_serializer.convert_relationship(relation) + if relation_dicts: + key2 = (relation.start_label, + relation.end_label, + relation.type, + len(relation_dicts[0])) + + file_suffix = '{}_{}_{}'.format(key2[0], key2[1], key2[2]) + relation_writer = self._get_writer(relation_dicts[0], + self._relation_file_mapping, + key2, + self._relation_dir, + file_suffix) + relation_writer.writerows(relation_dicts) + relation = csv_serializable.next_relation() + + def _get_writer(self, + csv_record_dict: Dict[str, Any], + file_mapping: Dict[Any, DictWriter], + key: Any, + dir_path: str, + file_suffix: str + ) -> DictWriter: + """ + Finds a writer based on csv record, key. + If writer does not exist, it's creates a csv writer and update the + mapping. + """ + writer = file_mapping.get(key) + if writer: + return writer + + LOGGER.info('Creating file for {}'.format(key)) + + file_out = open('{}/{}.csv'.format(dir_path, file_suffix), 'w', encoding='utf8') + writer = csv.DictWriter(file_out, fieldnames=csv_record_dict.keys(), quoting=csv.QUOTE_NONNUMERIC) + + def file_out_close() -> None: + LOGGER.info('Closing file IO {}'.format(file_out)) + file_out.close() + self._closer.register(file_out_close) + + writer.writeheader() + file_mapping[key] = writer + + return writer + + def close(self) -> None: + """ + Any closeable callable registered in _closer, it will close. + """ + self._closer.close() + + def get_scope(self) -> str: + return "loader.neptune_filesystem_csv" diff --git a/databuilder/databuilder/loader/generic_loader.py b/databuilder/databuilder/loader/generic_loader.py new file mode 100644 index 0000000000..38db553295 --- /dev/null +++ b/databuilder/databuilder/loader/generic_loader.py @@ -0,0 +1,53 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Any, Optional + +from pyhocon import ConfigTree + +from databuilder.loader.base_loader import Loader + +LOGGER = logging.getLogger(__name__) + +CALLBACK_FUNCTION = 'callback_function' + + +def log_call_back(record: Optional[Any]) -> None: + """ + A Sample callback function. Implement any function follows this function's signature to fit your needs. 
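+
+    For example, a GenericLoader can be pointed at any callable via the
+    CALLBACK_FUNCTION config key (an illustrative sketch):
+
+        from pyhocon import ConfigFactory
+
+        records = []
+        loader = GenericLoader()
+        loader.init(ConfigFactory.from_dict({CALLBACK_FUNCTION: records.append}))
+        loader.load({'name': 'example'})  # records == [{'name': 'example'}]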
+ :param record: + :return: + """ + LOGGER.info('record: %s', record) + + +class GenericLoader(Loader): + """ + Loader class to call back a function provided by user + """ + + def init(self, conf: ConfigTree) -> None: + """ + Initialize file handlers from conf + :param conf: + """ + self.conf = conf + self._callback_func = self.conf.get(CALLBACK_FUNCTION, log_call_back) + + def load(self, record: Optional[Any]) -> None: + """ + Write record to function + :param record: + :return: + """ + if not record: + return + + self._callback_func(record) + + def close(self) -> None: + pass + + def get_scope(self) -> str: + return "loader.generic" diff --git a/databuilder/databuilder/models/__init__.py b/databuilder/databuilder/models/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/models/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/models/application.py b/databuilder/databuilder/models/application.py new file mode 100644 index 0000000000..4e0f76884d --- /dev/null +++ b/databuilder/databuilder/models/application.py @@ -0,0 +1,242 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + Iterator, Optional, Union, +) + +from amundsen_common.utils.atlas import ( + AtlasCommonParams, AtlasCommonTypes, AtlasTableTypes, +) +from amundsen_rds.models import RDSModel +from amundsen_rds.models.application import Application as RDSApplication, ApplicationTable as RDSApplicationTable + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.table_metadata import TableMetadata +from databuilder.models.table_serializable import TableSerializable +from databuilder.serializers.atlas_serializer import get_entity_attrs +from databuilder.utils.atlas import AtlasRelationshipTypes, AtlasSerializedEntityOperation + + +class GenericApplication(GraphSerializable, TableSerializable, AtlasSerializable): + """ + An Application that generates or consumes a resource. 
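+
+    For example, attaching an Airflow task to a table it produces might look
+    like the following (an illustrative sketch; all identifiers and the URL
+    template are placeholders):
+
+        app = AirflowApplication(
+            task_id='load_users',
+            dag_id='users_pipeline',
+            application_url_template='https://airflow.example.com/tree?dag_id={dag_id}',
+            db_name='hive',
+            cluster='gold',
+            schema='core',
+            table_name='users',
+        )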
+ """ + + LABEL = 'Application' + DEFAULT_KEY_FORMAT = 'application://{application_type}/{application_id}' + + APP_URL = 'application_url' + APP_NAME = 'name' + APP_ID = 'id' + APP_DESCRIPTION = 'description' + + GENERATES_REL_TYPE = 'GENERATES' + DERIVED_FROM_REL_TYPE = 'DERIVED_FROM' + CONSUMES_REL_TYPE = 'CONSUMES' + CONSUMED_BY_REL_TYPE = 'CONSUMED_BY' + + LABELS_PERMITTED_TO_HAVE_USAGE = ['Table'] + + def __init__(self, + start_label: str, + start_key: str, + application_type: str, + application_id: str, + application_url: str, + application_description: Optional[str] = None, + app_key_override: Optional[str] = None, # for bw-compatibility only + generates_resource: bool = True, + ) -> None: + + if start_label not in GenericApplication.LABELS_PERMITTED_TO_HAVE_USAGE: + raise Exception(f'applications associated with {start_label} are not supported') + + self.start_label = start_label + self.start_key = start_key + self.application_type = application_type + self.application_id = application_id + self.application_url = application_url + self.application_description = application_description + self.application_key = app_key_override or GenericApplication.DEFAULT_KEY_FORMAT.format( + application_type=self.application_type, + application_id=self.application_id, + ) + self.generates_resource = generates_resource + + self._node_iter = self._create_node_iterator() + self._relation_iter = self._create_relation_iterator() + self._record_iter = self._create_record_iterator() + self._atlas_entity_iterator = self._create_next_atlas_entity() + self._atlas_relation_iterator = self._create_atlas_relation_iterator() + + def create_next_node(self) -> Union[GraphNode, None]: + # creates new node + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Union[GraphRelationship, None]: + try: + return next(self._relation_iter) + except StopIteration: + return None + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iter) + except StopIteration: + return None + + def _create_node_iterator(self) -> Iterator[GraphNode]: + """ + Create an application node + :return: + """ + attrs = { + GenericApplication.APP_NAME: self.application_type, + GenericApplication.APP_ID: self.application_id, + GenericApplication.APP_URL: self.application_url, + } + if self.application_description: + attrs[GenericApplication.APP_DESCRIPTION] = self.application_description + + yield GraphNode( + key=self.application_key, + label=GenericApplication.LABEL, + attributes=attrs, + ) + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + """ + Create relations between application and table nodes + :return: + """ + graph_relationship = GraphRelationship( + start_key=self.start_key, + start_label=self.start_label, + end_key=self.application_key, + end_label=GenericApplication.LABEL, + type=(GenericApplication.DERIVED_FROM_REL_TYPE if self.generates_resource + else GenericApplication.CONSUMED_BY_REL_TYPE), + reverse_type=(GenericApplication.GENERATES_REL_TYPE if self.generates_resource + else GenericApplication.CONSUMES_REL_TYPE), + attributes={} + ) + yield graph_relationship + + # TODO: support consuming/producing relationships and multiple apps per resource + def _create_record_iterator(self) -> Iterator[RDSModel]: + yield RDSApplication( + rk=self.application_key, + application_url=self.application_url, + name=self.application_type, + id=self.application_id, + description=self.application_description or '', + ) + + yield 
RDSApplicationTable( + rk=self.start_key, + application_rk=self.application_key, + ) + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + try: + return next(self._atlas_entity_iterator) + except StopIteration: + return None + + def _create_next_atlas_entity(self) -> Iterator[AtlasEntity]: + group_attrs_mapping = [ + (AtlasCommonParams.qualified_name, self.application_key), + ('name', self.application_type), + ('id', self.application_id), + ('description', self.application_description or ''), + ('application_url', self.application_url) + ] + + entity_attrs = get_entity_attrs(group_attrs_mapping) + + entity = AtlasEntity( + typeName=AtlasCommonTypes.application, + operation=AtlasSerializedEntityOperation.CREATE, + relationships=None, + attributes=entity_attrs, + ) + + yield entity + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + try: + return next(self._atlas_relation_iterator) + except StopIteration: + return None + + # TODO: support consuming/producing relationships and multiple apps per resource + def _create_atlas_relation_iterator(self) -> Iterator[AtlasRelationship]: + yield AtlasRelationship( + relationshipType=AtlasRelationshipTypes.table_application, + entityType1=AtlasTableTypes.table, + entityQualifiedName1=self.start_key, + entityType2=AtlasCommonTypes.application, + entityQualifiedName2=self.application_key, + attributes={} + ) + + +class AirflowApplication(GenericApplication): + + ID_FORMAT = '{dag}/{task}' + KEY_FORMAT = 'application://{cluster}.airflow/{id}' + DESCRIPTION_FORMAT = 'Airflow with id {id}' + + def __init__(self, + task_id: str, + dag_id: str, + application_url_template: str, + db_name: str = 'hive', + cluster: str = 'gold', + schema: str = '', + table_name: str = '', + application_type: str = 'Airflow', + exec_date: str = '', + generates_table: bool = True, + ) -> None: + + self.database = db_name + self.cluster = cluster + self.schema = schema + self.table = table_name + self.dag = dag_id + self.task = task_id + + airflow_app_id = AirflowApplication.ID_FORMAT.format(dag=dag_id, task=task_id) + GenericApplication.__init__( + self, + start_label=TableMetadata.TABLE_NODE_LABEL, + start_key=self.get_table_model_key(), + application_type=application_type, + application_id=airflow_app_id, + application_url=application_url_template.format(dag_id=dag_id), + application_description=AirflowApplication.DESCRIPTION_FORMAT.format(id=airflow_app_id), + app_key_override=AirflowApplication.KEY_FORMAT.format(cluster=cluster, id=airflow_app_id), + generates_resource=generates_table, + ) + + def get_table_model_key(self) -> str: + return TableMetadata.TABLE_KEY_FORMAT.format( + db=self.database, + cluster=self.cluster, + schema=self.schema, + tbl=self.table, + ) + + +# Alias for backwards compatibility +Application = AirflowApplication diff --git a/databuilder/databuilder/models/atlas_entity.py b/databuilder/databuilder/models/atlas_entity.py new file mode 100644 index 0000000000..45423f215a --- /dev/null +++ b/databuilder/databuilder/models/atlas_entity.py @@ -0,0 +1,14 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from collections import namedtuple + +AtlasEntity = namedtuple( + 'AtlasEntity', + [ + 'operation', + 'typeName', + 'relationships', + 'attributes' + ] +) diff --git a/databuilder/databuilder/models/atlas_relationship.py b/databuilder/databuilder/models/atlas_relationship.py new file mode 100644 index 0000000000..1700eea145 --- /dev/null +++ b/databuilder/databuilder/models/atlas_relationship.py @@ -0,0 +1,15 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 +from collections import namedtuple + +AtlasRelationship = namedtuple( + 'AtlasRelationship', + [ + 'relationshipType', + 'entityType1', + 'entityQualifiedName1', + 'entityType2', + 'entityQualifiedName2', + 'attributes', + ], +) diff --git a/databuilder/databuilder/models/atlas_serializable.py b/databuilder/databuilder/models/atlas_serializable.py new file mode 100644 index 0000000000..4c689ac4fb --- /dev/null +++ b/databuilder/databuilder/models/atlas_serializable.py @@ -0,0 +1,86 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import abc +from typing import Union + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship + + +class AtlasSerializable(object, metaclass=abc.ABCMeta): + """ + A Serializable abstract class asks subclass to implement next node or + next relation in dict form so that it can be serialized to CSV file. + + Any model class that needs to be pushed to a atlas should inherit this class. + """ + + @abc.abstractmethod + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + """ + Creates Atlas entity the process that consumes this class takes the output + serializes to atlas entity + + :return: a Atlas entity or None if no more records to serialize + """ + raise NotImplementedError + + @abc.abstractmethod + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + """ + Creates AtlasRelationship the process that consumes this class takes the output + serializes to the desired graph database. + + :return: a AtlasRelationship or None if no more record to serialize + """ + raise NotImplementedError + + def _validate_atlas_entity(self, entity: AtlasEntity) -> None: + operation, entity_type, relation, attributes = entity + + if entity_type is None: + raise ValueError('Required header missing: entityType') + + if operation is None: + raise ValueError('Required header missing: operation') + + if attributes is None: + raise ValueError('Required header missing: attributes') + + if 'qualifiedName' not in attributes: + raise ValueError('Attribute qualifiedName is missing') + + def _validate_atlas_relation(self, relation: AtlasRelationship) -> None: + relation_type, entity_type_1, qualified_name_1, entity_type_2, qualified_name_2, _ = relation + + if relation_type is None: + raise ValueError(f'Required header missing. Missing: {AtlasRelationship.relationshipType}') + + if entity_type_1 is None: + raise ValueError(f'Required header missing. Missing: {AtlasRelationship.entityType1}') + + if qualified_name_1 is None: + raise ValueError(f'Required header missing. Missing: {AtlasRelationship.entityQualifiedName1}') + + if entity_type_2 is None: + raise ValueError(f'Required header missing. Missing: {AtlasRelationship.entityType2}') + + if qualified_name_2 is None: + raise ValueError(f'Required header missing. 
Missing: {AtlasRelationship.entityQualifiedName2}') + + def next_atlas_entity(self) -> Union[AtlasEntity, None]: + entity_dict = self.create_next_atlas_entity() + if not entity_dict: + return None + + self._validate_atlas_entity(entity_dict) + return entity_dict + + def next_atlas_relation(self) -> Union[AtlasRelationship, None]: + relation_dict = self.create_next_atlas_relation() + if not relation_dict: + return None + + self._validate_atlas_relation(relation_dict) + return relation_dict diff --git a/databuilder/databuilder/models/badge.py b/databuilder/databuilder/models/badge.py new file mode 100644 index 0000000000..81a2a39902 --- /dev/null +++ b/databuilder/databuilder/models/badge.py @@ -0,0 +1,220 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + Iterator, List, Optional, Union, +) + +from amundsen_common.utils.atlas import AtlasCommonParams, AtlasCommonTypes +from amundsen_rds.models import RDSModel +from amundsen_rds.models.badge import Badge as RDSBadge +from amundsen_rds.models.column import ColumnBadge as RDSColumnBadge +from amundsen_rds.models.dashboard import DashboardBadge as RDSDashboardBadge +from amundsen_rds.models.table import TableBadge as RDSTableBadge + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.table_serializable import TableSerializable +from databuilder.serializers.atlas_serializer import get_entity_attrs +from databuilder.utils.atlas import AtlasRelationshipTypes, AtlasSerializedEntityOperation + + +class Badge: + def __init__(self, name: str, category: str): + # Amundsen UI always formats badge display with first letter capitalized while other letters are lowercase. + # Clicking table badges in UI always results in searching lower cases badges + # https://github.com/amundsen-io/amundsen/blob/6ec9b398634264e52089bb9e1b7d76a6fb6a35a4/frontend/amundsen_application/static/js/components/BadgeList/index.tsx#L56 + # If badges stored in neo4j are not lowercase, they won't be searchable in UI. + self.name = name.lower() + self.category = category.lower() + + def __repr__(self) -> str: + return f'Badge({self.name!r}, {self.category!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Badge): + return NotImplemented + return self.name == other.name and self.category == other.category + + +class BadgeMetadata(GraphSerializable, TableSerializable, AtlasSerializable): + """ + Badge model. 
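+
+    For example, attaching badges to a table might look like the following
+    (an illustrative sketch; the table key and badge names are placeholders):
+
+        badges = BadgeMetadata(
+            start_label='Table',
+            start_key='hive://gold.core/users',
+            badges=[Badge('certified', 'table_status'), Badge('pii', 'data')],
+        )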
+ """ + BADGE_NODE_LABEL = 'Badge' + BADGE_KEY_FORMAT = '{badge}' + BADGE_CATEGORY = 'category' + + # Relation between entity and badge + BADGE_RELATION_TYPE = 'HAS_BADGE' + INVERSE_BADGE_RELATION_TYPE = 'BADGE_FOR' + + LABELS_PERMITTED_TO_HAVE_BADGE = ['Table', 'Dashboard', 'Column', 'Feature', 'Type_Metadata'] + + def __init__(self, + start_label: str, + start_key: str, + badges: List[Badge], + ): + if start_label not in BadgeMetadata.LABELS_PERMITTED_TO_HAVE_BADGE: + raise Exception(f'badges for {start_label} are not supported') + self.start_label = start_label + self.start_key = start_key + self.badges = badges + + self._node_iter = self._create_node_iterator() + self._relation_iter = self._create_relation_iterator() + self._record_iter = self._create_record_iterator() + self._atlas_entity_iterator = self._create_next_atlas_entity() + self._atlas_relation_iterator = self._create_atlas_relation_iterator() + + def __repr__(self) -> str: + return f'BadgeMetadata({self.start_label!r}, {self.start_key!r}, {self.badges!r})' + + def create_next_node(self) -> Optional[GraphNode]: + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Optional[GraphRelationship]: + try: + return next(self._relation_iter) + except StopIteration: + return None + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iter) + except StopIteration: + return None + + @staticmethod + def get_badge_key(name: str) -> str: + if not name: + return '' + return BadgeMetadata.BADGE_KEY_FORMAT.format(badge=name) + + def get_badge_nodes(self) -> List[GraphNode]: + nodes = [] + for badge in self.badges: + if badge: + node = GraphNode( + key=self.get_badge_key(badge.name), + label=self.BADGE_NODE_LABEL, + attributes={ + self.BADGE_CATEGORY: badge.category + } + ) + nodes.append(node) + return nodes + + def get_badge_relations(self) -> List[GraphRelationship]: + relations = [] + for badge in self.badges: + relation = GraphRelationship( + start_label=self.start_label, + end_label=self.BADGE_NODE_LABEL, + start_key=self.start_key, + end_key=self.get_badge_key(badge.name), + type=self.BADGE_RELATION_TYPE, + reverse_type=self.INVERSE_BADGE_RELATION_TYPE, + attributes={} + ) + relations.append(relation) + return relations + + def get_badge_records(self) -> List[RDSModel]: + records = [] + for badge in self.badges: + if badge: + record = RDSBadge( + rk=self.get_badge_key(badge.name), + category=badge.category + ) + records.append(record) + + return records + + def _create_node_iterator(self) -> Iterator[GraphNode]: + """ + Create badge nodes + :return: + """ + nodes = self.get_badge_nodes() + for node in nodes: + yield node + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + relations = self.get_badge_relations() + for relation in relations: + yield relation + + def _create_record_iterator(self) -> Iterator[RDSModel]: + records = self.get_badge_records() + for record in records: + yield record + + if self.start_label == 'Table': + table_badge_record = RDSTableBadge(table_rk=self.start_key, badge_rk=record.rk) + yield table_badge_record + elif self.start_label == 'Column': + column_badge_record = RDSColumnBadge(column_rk=self.start_key, badge_rk=record.rk) + yield column_badge_record + elif self.start_label == 'Dashboard': + dashboard_badge_record = RDSDashboardBadge(dashboard_rk=self.start_key, badge_rk=record.rk) + yield dashboard_badge_record + + def _create_atlas_classification_entity(self, badge: Badge) -> 
AtlasEntity: + attrs_mapping = [ + (AtlasCommonParams.qualified_name, badge.name), + ('category', badge.category), + ('name', badge.name) + ] + + entity_attrs = get_entity_attrs(attrs_mapping) + + entity = AtlasEntity( + typeName=AtlasCommonTypes.badge, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=entity_attrs, + relationships=None + ) + + return entity + + def _create_atlas_classification_relation(self, badge: Badge) -> AtlasRelationship: + table_relationship = AtlasRelationship( + relationshipType=AtlasRelationshipTypes.badge, + entityType1=AtlasCommonTypes.data_set, + entityQualifiedName1=self.start_key, + entityType2=AtlasRelationshipTypes.badge, + entityQualifiedName2=badge.name, + attributes={} + ) + + return table_relationship + + def _create_atlas_relation_iterator(self) -> Iterator[AtlasRelationship]: + for badge in self.badges: + yield self._create_atlas_classification_relation(badge) + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + try: + return next(self._atlas_relation_iterator) + except StopIteration: + return None + + def _create_next_atlas_entity(self) -> Iterator[AtlasEntity]: + for badge in self.badges: + yield self._create_atlas_classification_entity(badge) + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + try: + return next(self._atlas_entity_iterator) + except StopIteration: + return None diff --git a/databuilder/databuilder/models/cluster/__init__.py b/databuilder/databuilder/models/cluster/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/models/cluster/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/models/cluster/cluster_constants.py b/databuilder/databuilder/models/cluster/cluster_constants.py new file mode 100644 index 0000000000..2a16902541 --- /dev/null +++ b/databuilder/databuilder/models/cluster/cluster_constants.py @@ -0,0 +1,9 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +CLUSTER_NODE_LABEL = 'Cluster' + +CLUSTER_RELATION_TYPE = 'CLUSTER' +CLUSTER_REVERSE_RELATION_TYPE = 'CLUSTER_OF' + +CLUSTER_NAME_PROP_KEY = 'name' diff --git a/databuilder/databuilder/models/dashboard/__init__.py b/databuilder/databuilder/models/dashboard/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/models/dashboard/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/models/dashboard/dashboard_chart.py b/databuilder/databuilder/models/dashboard/dashboard_chart.py new file mode 100644 index 0000000000..1afdefc910 --- /dev/null +++ b/databuilder/databuilder/models/dashboard/dashboard_chart.py @@ -0,0 +1,200 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import ( + Any, Iterator, Optional, Union, +) + +from amundsen_common.utils.atlas import AtlasCommonParams, AtlasDashboardTypes +from amundsen_rds.models import RDSModel +from amundsen_rds.models.dashboard import DashboardChart as RDSDashboardChart + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.dashboard.dashboard_query import DashboardQuery +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.table_serializable import TableSerializable +from databuilder.serializers.atlas_serializer import ( + add_entity_relationship, get_entity_attrs, get_entity_relationships, +) +from databuilder.utils.atlas import AtlasSerializedEntityOperation + +LOGGER = logging.getLogger(__name__) + + +class DashboardChart(GraphSerializable, TableSerializable, AtlasSerializable): + """ + A model that encapsulate Dashboard's charts + """ + DASHBOARD_CHART_LABEL = 'Chart' + DASHBOARD_CHART_KEY_FORMAT = '{product}_dashboard://{cluster}.{dashboard_group_id}/' \ + '{dashboard_id}/query/{query_id}/chart/{chart_id}' + CHART_RELATION_TYPE = 'HAS_CHART' + CHART_REVERSE_RELATION_TYPE = 'CHART_OF' + + def __init__(self, + dashboard_group_id: Optional[str], + dashboard_id: Optional[str], + query_id: str, + chart_id: str, + chart_name: Optional[str] = None, + chart_type: Optional[str] = None, + chart_url: Optional[str] = None, + product: Optional[str] = '', + cluster: str = 'gold', + **kwargs: Any + ) -> None: + self._dashboard_group_id = dashboard_group_id + self._dashboard_id = dashboard_id + self._query_id = query_id + self._chart_id = chart_id if chart_id else chart_name + self._chart_name = chart_name + self._chart_type = chart_type + self._chart_url = chart_url + self._product = product + self._cluster = cluster + self._node_iterator = self._create_node_iterator() + self._relation_iterator = self._create_relation_iterator() + self._record_iterator = self._create_record_iterator() + self._atlas_entity_iterator = self._create_next_atlas_entity() + + def create_next_node(self) -> Union[GraphNode, None]: + try: + return next(self._node_iterator) + except StopIteration: + return None + + def _create_node_iterator(self) -> Iterator[GraphNode]: + node_attributes = { + 'id': self._chart_id + } + + if self._chart_name: + node_attributes['name'] = self._chart_name + + if self._chart_type: + node_attributes['type'] = self._chart_type + + if self._chart_url: + node_attributes['url'] = self._chart_url + + node = GraphNode( + key=self._get_chart_node_key(), + label=DashboardChart.DASHBOARD_CHART_LABEL, + attributes=node_attributes + ) + yield node + + def create_next_relation(self) -> Union[GraphRelationship, None]: + try: + return next(self._relation_iterator) + except StopIteration: + return None + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + relationship = GraphRelationship( + start_label=DashboardQuery.DASHBOARD_QUERY_LABEL, + start_key=DashboardQuery.DASHBOARD_QUERY_KEY_FORMAT.format( + product=self._product, + cluster=self._cluster, + dashboard_group_id=self._dashboard_group_id, + dashboard_id=self._dashboard_id, + query_id=self._query_id + ), + end_label=DashboardChart.DASHBOARD_CHART_LABEL, + 
end_key=self._get_chart_node_key(), + type=DashboardChart.CHART_RELATION_TYPE, + reverse_type=DashboardChart.CHART_REVERSE_RELATION_TYPE, + attributes={} + ) + yield relationship + + def _get_chart_node_key(self) -> str: + return DashboardChart.DASHBOARD_CHART_KEY_FORMAT.format( + product=self._product, + cluster=self._cluster, + dashboard_group_id=self._dashboard_group_id, + dashboard_id=self._dashboard_id, + query_id=self._query_id, + chart_id=self._chart_id + ) + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iterator) + except StopIteration: + return None + + def _create_record_iterator(self) -> Iterator[RDSModel]: + record = RDSDashboardChart( + rk=self._get_chart_node_key(), + id=self._chart_id, + query_rk=DashboardQuery.DASHBOARD_QUERY_KEY_FORMAT.format( + product=self._product, + cluster=self._cluster, + dashboard_group_id=self._dashboard_group_id, + dashboard_id=self._dashboard_id, + query_id=self._query_id + ) + ) + if self._chart_name: + record.name = self._chart_name + if self._chart_type: + record.type = self._chart_type + if self._chart_url: + record.url = self._chart_url + + yield record + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + try: + return next(self._atlas_entity_iterator) + except StopIteration: + return None + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + return None + + def _create_next_atlas_entity(self) -> Iterator[AtlasEntity]: + # Chart + attrs_mapping = [ + (AtlasCommonParams.qualified_name, self._get_chart_node_key()), + ('name', self._chart_name), + ('type', self._chart_type), + ('url', self._chart_url) + ] + + chart_entity_attrs = get_entity_attrs(attrs_mapping) + + relationship_list = [] # type: ignore + + add_entity_relationship( + relationship_list, + 'query', + AtlasDashboardTypes.query, + DashboardQuery.DASHBOARD_QUERY_KEY_FORMAT.format( + product=self._product, + cluster=self._cluster, + dashboard_group_id=self._dashboard_group_id, + dashboard_id=self._dashboard_id, + query_id=self._query_id + ) + ) + + chart_entity = AtlasEntity( + typeName=AtlasDashboardTypes.chart, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=chart_entity_attrs, + relationships=get_entity_relationships(relationship_list) + ) + + yield chart_entity + + def __repr__(self) -> str: + return f'DashboardChart({self._dashboard_group_id!r}, {self._dashboard_id!r}, ' \ + f'{self._query_id!r}, {self._chart_id!r}, {self._chart_name!r}, {self._chart_type!r}, ' \ + f'{self._chart_url!r}, {self._product!r}, {self._cluster!r})' diff --git a/databuilder/databuilder/models/dashboard/dashboard_execution.py b/databuilder/databuilder/models/dashboard/dashboard_execution.py new file mode 100644 index 0000000000..02e61a2f69 --- /dev/null +++ b/databuilder/databuilder/models/dashboard/dashboard_execution.py @@ -0,0 +1,176 @@ +# Copyright Contributors to the Amundsen project. 
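# ---------------------------------------------------------------------------
# Illustrative usage sketch (editorial addition, not part of the diff above)
# for the DashboardChart model added in
# databuilder/databuilder/models/dashboard/dashboard_chart.py. Values such as
# 'mode', 'sales' and the ids are placeholders.
from databuilder.models.dashboard.dashboard_chart import DashboardChart

chart = DashboardChart(
    dashboard_group_id='sales',
    dashboard_id='revenue_overview',
    query_id='5',
    chart_id='9',
    chart_name='Monthly revenue',
    chart_type='bar',
    chart_url='https://mode.example.com/charts/9',
    product='mode',        # cluster defaults to 'gold'
)

# Chart node key produced by DASHBOARD_CHART_KEY_FORMAT:
#   mode_dashboard://gold.sales/revenue_overview/query/5/chart/9
node = chart.create_next_node()
relation = chart.create_next_relation()   # Query -[HAS_CHART]-> Chart
record = chart.create_next_record()       # RDSDashboardChart row for the RDS loader
# ---------------------------------------------------------------------------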
+# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import ( + Any, Iterator, Optional, Union, +) + +from amundsen_common.utils.atlas import AtlasCommonParams, AtlasDashboardTypes +from amundsen_rds.models import RDSModel +from amundsen_rds.models.dashboard import DashboardExecution as RDSDashboardExecution + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.table_serializable import TableSerializable +from databuilder.serializers.atlas_serializer import ( + add_entity_relationship, get_entity_attrs, get_entity_relationships, +) +from databuilder.utils.atlas import AtlasSerializedEntityOperation + +LOGGER = logging.getLogger(__name__) + + +class DashboardExecution(GraphSerializable, TableSerializable, AtlasSerializable): + """ + A model that encapsulate Dashboard's execution timestamp in epoch and execution state + """ + DASHBOARD_EXECUTION_LABEL = 'Execution' + DASHBOARD_EXECUTION_KEY_FORMAT = '{product}_dashboard://{cluster}.{dashboard_group_id}/' \ + '{dashboard_id}/execution/{execution_id}' + DASHBOARD_EXECUTION_RELATION_TYPE = 'EXECUTED' + EXECUTION_DASHBOARD_RELATION_TYPE = 'EXECUTION_OF' + + LAST_EXECUTION_ID = '_last_execution' + LAST_SUCCESSFUL_EXECUTION_ID = '_last_successful_execution' + + def __init__(self, + dashboard_group_id: Optional[str], + dashboard_id: Optional[str], + execution_timestamp: int, + execution_state: str, + execution_id: str = LAST_EXECUTION_ID, + product: Optional[str] = '', + cluster: str = 'gold', + **kwargs: Any + ) -> None: + self._dashboard_group_id = dashboard_group_id + self._dashboard_id = dashboard_id + self._execution_timestamp = execution_timestamp + self._execution_state = execution_state + self._execution_id = execution_id + self._product = product + self._cluster = cluster + self._node_iterator = self._create_node_iterator() + self._relation_iterator = self._create_relation_iterator() + self._record_iterator = self._create_record_iterator() + self._atlas_entity_iterator = self._create_next_atlas_entity() + + def create_next_node(self) -> Union[GraphNode, None]: + try: + return next(self._node_iterator) + except StopIteration: + return None + + def _create_node_iterator(self) -> Iterator[GraphNode]: + node = GraphNode( + key=self._get_last_execution_node_key(), + label=DashboardExecution.DASHBOARD_EXECUTION_LABEL, + attributes={ + 'timestamp': self._execution_timestamp, + 'state': self._execution_state + } + ) + yield node + + def create_next_relation(self) -> Union[GraphRelationship, None]: + try: + return next(self._relation_iterator) + except StopIteration: + return None + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + relationship = GraphRelationship( + start_label=DashboardMetadata.DASHBOARD_NODE_LABEL, + start_key=DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + product=self._product, + cluster=self._cluster, + dashboard_group=self._dashboard_group_id, + dashboard_name=self._dashboard_id + ), + end_label=DashboardExecution.DASHBOARD_EXECUTION_LABEL, + end_key=self._get_last_execution_node_key(), + type=DashboardExecution.DASHBOARD_EXECUTION_RELATION_TYPE, + 
reverse_type=DashboardExecution.EXECUTION_DASHBOARD_RELATION_TYPE, + attributes={} + ) + yield relationship + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iterator) + except StopIteration: + return None + + def _create_record_iterator(self) -> Iterator[RDSModel]: + yield RDSDashboardExecution( + rk=self._get_last_execution_node_key(), + timestamp=self._execution_timestamp, + state=self._execution_state, + dashboard_rk=DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + product=self._product, + cluster=self._cluster, + dashboard_group=self._dashboard_group_id, + dashboard_name=self._dashboard_id + ) + ) + + def _get_last_execution_node_key(self) -> str: + return DashboardExecution.DASHBOARD_EXECUTION_KEY_FORMAT.format( + product=self._product, + cluster=self._cluster, + dashboard_group_id=self._dashboard_group_id, + dashboard_id=self._dashboard_id, + execution_id=self._execution_id + ) + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + try: + return next(self._atlas_entity_iterator) + except StopIteration: + return None + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + return None + + def _create_next_atlas_entity(self) -> Iterator[AtlasEntity]: + attrs_mapping = [ + (AtlasCommonParams.qualified_name, self._get_last_execution_node_key()), + ('state', self._execution_state), + ('timestamp', self._execution_timestamp) + ] + + query_entity_attrs = get_entity_attrs(attrs_mapping) + + relationship_list = [] # type: ignore + + add_entity_relationship( + relationship_list, + 'dashboard', + AtlasDashboardTypes.metadata, + DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + product=self._product, + cluster=self._cluster, + dashboard_group=self._dashboard_group_id, + dashboard_name=self._dashboard_id + ) + ) + + execution_entity = AtlasEntity( + typeName=AtlasDashboardTypes.execution, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=query_entity_attrs, + relationships=get_entity_relationships(relationship_list) + ) + + yield execution_entity + + def __repr__(self) -> str: + return f'DashboardExecution({self._dashboard_group_id!r}, {self._dashboard_id!r}, ' \ + f'{self._execution_timestamp!r}, {self._execution_state!r}, ' \ + f'{self._execution_id!r}, {self._product!r}, {self._cluster!r})' diff --git a/databuilder/databuilder/models/dashboard/dashboard_last_modified.py b/databuilder/databuilder/models/dashboard/dashboard_last_modified.py new file mode 100644 index 0000000000..f6331de714 --- /dev/null +++ b/databuilder/databuilder/models/dashboard/dashboard_last_modified.py @@ -0,0 +1,158 @@ +# Copyright Contributors to the Amundsen project. 
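# ---------------------------------------------------------------------------
# Illustrative usage sketch (editorial addition, not part of the diff above)
# for DashboardExecution from dashboard_execution.py. The timestamp and names
# are placeholders; LAST_SUCCESSFUL_EXECUTION_ID is one of the two sentinel
# execution ids defined on the class.
from databuilder.models.dashboard.dashboard_execution import DashboardExecution

execution = DashboardExecution(
    dashboard_group_id='sales',
    dashboard_id='revenue_overview',
    execution_timestamp=1622520000,
    execution_state='succeeded',
    execution_id=DashboardExecution.LAST_SUCCESSFUL_EXECUTION_ID,
    product='mode',
)

# Execution node key produced by DASHBOARD_EXECUTION_KEY_FORMAT:
#   mode_dashboard://gold.sales/revenue_overview/execution/_last_successful_execution
node = execution.create_next_node()          # Execution node (timestamp, state)
relation = execution.create_next_relation()  # Dashboard -[EXECUTED]-> Execution
# ---------------------------------------------------------------------------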
+# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import ( + Any, Iterator, Optional, Union, +) + +from amundsen_common.utils.atlas import AtlasCommonParams, AtlasDashboardTypes +from amundsen_rds.models import RDSModel +from amundsen_rds.models.dashboard import DashboardTimestamp as RDSDashboardTimestamp + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.table_serializable import TableSerializable +from databuilder.models.timestamp import timestamp_constants +from databuilder.serializers.atlas_serializer import get_entity_attrs +from databuilder.utils.atlas import AtlasSerializedEntityOperation + +LOGGER = logging.getLogger(__name__) + + +class DashboardLastModifiedTimestamp(GraphSerializable, TableSerializable, AtlasSerializable): + """ + A model that encapsulate Dashboard's last modified timestamp in epoch + """ + + DASHBOARD_LAST_MODIFIED_KEY_FORMAT = '{product}_dashboard://{cluster}.{dashboard_group_id}/' \ + '{dashboard_id}/_last_modified_timestamp' + + def __init__(self, + dashboard_group_id: Optional[str], + dashboard_id: Optional[str], + last_modified_timestamp: int, + product: Optional[str] = '', + cluster: str = 'gold', + **kwargs: Any + ) -> None: + self._dashboard_group_id = dashboard_group_id + self._dashboard_id = dashboard_id + self._last_modified_timestamp = last_modified_timestamp + self._product = product + self._cluster = cluster + self._node_iterator = self._create_node_iterator() + self._relation_iterator = self._create_relation_iterator() + self._record_iterator = self._create_record_iterator() + self._atlas_entity_iterator = self._create_next_atlas_entity() + + def create_next_node(self) -> Union[GraphNode, None]: + try: + return next(self._node_iterator) + except StopIteration: + return None + + def _create_node_iterator(self) -> Iterator[GraphNode]: + node_attributes = { + timestamp_constants.TIMESTAMP_PROPERTY: self._last_modified_timestamp, + timestamp_constants.TIMESTAMP_NAME_PROPERTY: timestamp_constants.TimestampName.last_updated_timestamp.name + } + node = GraphNode( + key=self._get_last_modified_node_key(), + label=timestamp_constants.NODE_LABEL, + attributes=node_attributes + ) + yield node + + def create_next_relation(self) -> Union[GraphRelationship, None]: + try: + return next(self._relation_iterator) + except StopIteration: + return None + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + relationship = GraphRelationship( + start_key=DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + product=self._product, + cluster=self._cluster, + dashboard_group=self._dashboard_group_id, + dashboard_name=self._dashboard_id + ), + start_label=DashboardMetadata.DASHBOARD_NODE_LABEL, + end_key=self._get_last_modified_node_key(), + end_label=timestamp_constants.NODE_LABEL, + type=timestamp_constants.LASTUPDATED_RELATION_TYPE, + reverse_type=timestamp_constants.LASTUPDATED_REVERSE_RELATION_TYPE, + attributes={} + ) + yield relationship + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + try: + return next(self._atlas_entity_iterator) + except StopIteration: + return None + + def 
_create_next_atlas_entity(self) -> Iterator[AtlasEntity]: + + # last modified + attrs_mapping = [ + ( + AtlasCommonParams.qualified_name, DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + product=self._product, + cluster=self._cluster, + dashboard_group=self._dashboard_group_id, + dashboard_name=self._dashboard_id + ) + ), + (AtlasCommonParams.last_modified_timestamp, self._last_modified_timestamp), + ] + + dashboard_entity_attrs = get_entity_attrs(attrs_mapping) + + last_modified = AtlasEntity( + typeName=AtlasDashboardTypes.metadata, + operation=AtlasSerializedEntityOperation.UPDATE, + relationships=None, + attributes=dashboard_entity_attrs + ) + yield last_modified + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + return None + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iterator) + except StopIteration: + return None + + def _create_record_iterator(self) -> Iterator[RDSModel]: + yield RDSDashboardTimestamp( + rk=self._get_last_modified_node_key(), + timestamp=self._last_modified_timestamp, + name=timestamp_constants.TimestampName.last_updated_timestamp.name, + dashboard_rk=DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + product=self._product, + cluster=self._cluster, + dashboard_group=self._dashboard_group_id, + dashboard_name=self._dashboard_id + ) + ) + + def _get_last_modified_node_key(self) -> str: + return DashboardLastModifiedTimestamp.DASHBOARD_LAST_MODIFIED_KEY_FORMAT.format( + product=self._product, + cluster=self._cluster, + dashboard_group_id=self._dashboard_group_id, + dashboard_id=self._dashboard_id, + ) + + def __repr__(self) -> str: + return f'DashboardLastModifiedTimestamp({self._dashboard_group_id!r}, {self._dashboard_id!r}, ' \ + f'{self._last_modified_timestamp!r}, {self._product!r}, {self._cluster!r})' diff --git a/databuilder/databuilder/models/dashboard/dashboard_metadata.py b/databuilder/databuilder/models/dashboard/dashboard_metadata.py new file mode 100644 index 0000000000..65e679e252 --- /dev/null +++ b/databuilder/databuilder/models/dashboard/dashboard_metadata.py @@ -0,0 +1,437 @@ +# Copyright Contributors to the Amundsen project. 
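# ---------------------------------------------------------------------------
# Illustrative usage sketch (editorial addition, not part of the diff above)
# for DashboardLastModifiedTimestamp from dashboard_last_modified.py; the
# identifiers and timestamp below are placeholders.
from databuilder.models.dashboard.dashboard_last_modified import DashboardLastModifiedTimestamp

last_modified = DashboardLastModifiedTimestamp(
    dashboard_group_id='sales',
    dashboard_id='revenue_overview',
    last_modified_timestamp=1622520000,
    product='mode',
)

# Timestamp node key produced by DASHBOARD_LAST_MODIFIED_KEY_FORMAT:
#   mode_dashboard://gold.sales/revenue_overview/_last_modified_timestamp
node = last_modified.create_next_node()          # Timestamp node
relation = last_modified.create_next_relation()  # Dashboard -> Timestamp (LASTUPDATED_RELATION_TYPE)
# ---------------------------------------------------------------------------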
+# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + Any, Dict, Iterator, List, Optional, Set, Tuple, Union, +) + +from amundsen_common.utils.atlas import AtlasCommonParams, AtlasDashboardTypes +from amundsen_rds.models import RDSModel +from amundsen_rds.models.dashboard import ( + Dashboard as RDSDashboard, DashboardCluster as RDSDashboardCluster, DashboardDescription as RDSDashboardDescription, + DashboardGroup as RDSDashboardGroup, DashboardGroupDescription as RDSDashboardGroupDescription, + DashboardTag as RDSDashboardTag, +) +from amundsen_rds.models.tag import Tag as RDSTag + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.cluster import cluster_constants +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +# TODO: We could separate TagMetadata from table_metadata to own module +from databuilder.models.table_metadata import TagMetadata +from databuilder.models.table_serializable import TableSerializable +from databuilder.serializers.atlas_serializer import ( + add_entity_relationship, get_entity_attrs, get_entity_relationships, +) +from databuilder.utils.atlas import AtlasSerializedEntityOperation + + +class DashboardMetadata(GraphSerializable, TableSerializable, AtlasSerializable): + """ + Dashboard metadata including dashboard group name, dashboardgroup description, dashboard description, + and tags. + + Some other metadata e.g. Owners and last-reload/modified times are provided by other models + e.g. DashboardOwner + + It implements Neo4jCsvSerializable so that it can be serialized to produce + Dashboard, Tag, Description and relations between those. Additionally, it will create a + Dashboardgroup with relationships to the Dashboard. 
+ """ + CLUSTER_KEY_FORMAT = '{product}_dashboard://{cluster}' + CLUSTER_DASHBOARD_GROUP_RELATION_TYPE = 'DASHBOARD_GROUP' + DASHBOARD_GROUP_CLUSTER_RELATION_TYPE = 'DASHBOARD_GROUP_OF' + + DASHBOARD_NODE_LABEL = 'Dashboard' + DASHBOARD_KEY_FORMAT = '{product}_dashboard://{cluster}.{dashboard_group}/{dashboard_name}' + DASHBOARD_NAME = 'name' + DASHBOARD_CREATED_TIME_STAMP = 'created_timestamp' + DASHBOARD_GROUP_URL = 'dashboard_group_url' + DASHBOARD_URL = 'dashboard_url' + + DASHBOARD_DESCRIPTION_NODE_LABEL = 'Description' + DASHBOARD_DESCRIPTION = 'description' + DASHBOARD_DESCRIPTION_FORMAT = \ + '{product}_dashboard://{cluster}.{dashboard_group}/{dashboard_name}/_description' + DASHBOARD_DESCRIPTION_RELATION_TYPE = 'DESCRIPTION' + DESCRIPTION_DASHBOARD_RELATION_TYPE = 'DESCRIPTION_OF' + + DASHBOARD_GROUP_NODE_LABEL = 'Dashboardgroup' + DASHBOARD_GROUP_KEY_FORMAT = '{product}_dashboard://{cluster}.{dashboard_group}' + DASHBOARD_GROUP_DASHBOARD_RELATION_TYPE = 'DASHBOARD' + DASHBOARD_DASHBOARD_GROUP_RELATION_TYPE = 'DASHBOARD_OF' + + DASHBOARD_GROUP_DESCRIPTION_KEY_FORMAT = '{product}_dashboard://{cluster}.{dashboard_group}/_description' + + DASHBOARD_TAG_RELATION_TYPE = 'TAG' + TAG_DASHBOARD_RELATION_TYPE = 'TAG_OF' + + ATLAS_DASHBOARD_KEY_FORMAT = '{product}_dashboard://{cluster}.{dashboard_group_id}/{dashboard_id}' + ATLAS_DASHBOARD_GROUP_KEY_FORMAT = '{product}_dashboard://{cluster}.{dashboard_group_id}' + + serialized_nodes: Set[Any] = set() + serialized_rels: Set[Any] = set() + + def __init__(self, + dashboard_group: str, + dashboard_name: str, + description: Union[str, None], + tags: Optional[List] = None, + cluster: str = 'gold', + product: Optional[str] = '', + dashboard_group_id: Optional[str] = None, + dashboard_id: Optional[str] = None, + dashboard_group_description: Optional[str] = None, + created_timestamp: Optional[int] = None, + dashboard_group_url: Optional[str] = None, + dashboard_url: Optional[str] = None, + **kwargs: Any + ) -> None: + + self.dashboard_group = dashboard_group + self.dashboard_name = dashboard_name + self.dashboard_group_id = dashboard_group_id if dashboard_group_id else dashboard_group + self.dashboard_id = dashboard_id if dashboard_id else dashboard_name + self.description = description + self.tags = tags + self.product = product + self.cluster = cluster + self.dashboard_group_description = dashboard_group_description + self.created_timestamp = created_timestamp + self.dashboard_group_url = dashboard_group_url + self.dashboard_url = dashboard_url + self._processed_cluster: Set[str] = set() + self._processed_dashboard_group: Set[str] = set() + self._node_iterator = self._create_next_node() + self._relation_iterator = self._create_next_relation() + self._record_iterator = self._create_record_iterator() + self._atlas_entity_iterator = self._create_next_atlas_entity() + + def __repr__(self) -> str: + return f'DashboardMetadata(' \ + f'{self.dashboard_group!r}, {self.dashboard_name!r}, {self.description!r}, {self.tags!r}, ' \ + f'{self.dashboard_group_id!r}, {self.dashboard_id!r}, {self.dashboard_group_description!r}, ' \ + f'{self.created_timestamp!r}, {self.dashboard_group_url!r}, {self.dashboard_url!r})' + + def _get_cluster_key(self) -> str: + return DashboardMetadata.CLUSTER_KEY_FORMAT.format(cluster=self.cluster, + product=self.product) + + def _get_dashboard_key(self) -> str: + return DashboardMetadata.DASHBOARD_KEY_FORMAT.format(dashboard_group=self.dashboard_group_id, + dashboard_name=self.dashboard_id, + cluster=self.cluster, + 
product=self.product) + + def _get_dashboard_description_key(self) -> str: + return DashboardMetadata.DASHBOARD_DESCRIPTION_FORMAT.format(dashboard_group=self.dashboard_group_id, + dashboard_name=self.dashboard_id, + cluster=self.cluster, + product=self.product) + + def _get_dashboard_group_description_key(self) -> str: + return DashboardMetadata.DASHBOARD_GROUP_DESCRIPTION_KEY_FORMAT.format(dashboard_group=self.dashboard_group_id, + cluster=self.cluster, + product=self.product) + + def _get_dashboard_group_key(self) -> str: + return DashboardMetadata.DASHBOARD_GROUP_KEY_FORMAT.format(dashboard_group=self.dashboard_group_id, + cluster=self.cluster, + product=self.product) + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + try: + return next(self._atlas_entity_iterator) + except StopIteration: + return None + + def _create_next_atlas_entity(self) -> Iterator[AtlasEntity]: + + # dashboard group + group_attrs_mapping = [ + (AtlasCommonParams.qualified_name, self._get_dashboard_group_key()), + ('name', self.dashboard_group), + ('id', self.dashboard_group_id), + ('description', self.dashboard_group_description), + ('url', self.dashboard_group_url), + ] + dashboard_group_entity_attrs = get_entity_attrs(group_attrs_mapping) + + dashboard_group_entity = AtlasEntity( + typeName=AtlasDashboardTypes.group, + operation=AtlasSerializedEntityOperation.CREATE, + relationships=None, + attributes=dashboard_group_entity_attrs, + ) + + yield dashboard_group_entity + + # dashboard + attrs_mapping: List[Tuple[Any, Any]] = [ + (AtlasCommonParams.qualified_name, self._get_dashboard_key()), + ('name', self.dashboard_name), + ('description', self.description), + ('url', self.dashboard_url), + ('cluster', self.cluster), + ('product', self.product), + (AtlasCommonParams.created_timestamp, self.created_timestamp), + ] + + dashboard_entity_attrs = get_entity_attrs(attrs_mapping) + + relationship_list = [] # type: ignore + + add_entity_relationship( + relationship_list, + 'group', + AtlasDashboardTypes.group, + self._get_dashboard_group_key(), + ) + + dashboard_entity = AtlasEntity( + typeName=AtlasDashboardTypes.metadata, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=dashboard_entity_attrs, + relationships=get_entity_relationships(relationship_list), + ) + yield dashboard_entity + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + pass + + def create_next_node(self) -> Union[GraphNode, None]: + try: + return next(self._node_iterator) + except StopIteration: + return None + + def _create_next_node(self) -> Iterator[GraphNode]: + # Cluster node + if not self._get_cluster_key() in self._processed_cluster: + self._processed_cluster.add(self._get_cluster_key()) + cluster_node = GraphNode( + key=self._get_cluster_key(), + label=cluster_constants.CLUSTER_NODE_LABEL, + attributes={ + cluster_constants.CLUSTER_NAME_PROP_KEY: self.cluster + } + ) + yield cluster_node + + # Dashboard node attributes + dashboard_node_attributes: Dict[str, Any] = { + DashboardMetadata.DASHBOARD_NAME: self.dashboard_name, + } + if self.created_timestamp: + dashboard_node_attributes[DashboardMetadata.DASHBOARD_CREATED_TIME_STAMP] = self.created_timestamp + + if self.dashboard_url: + dashboard_node_attributes[DashboardMetadata.DASHBOARD_URL] = self.dashboard_url + + dashboard_node = GraphNode( + key=self._get_dashboard_key(), + label=DashboardMetadata.DASHBOARD_NODE_LABEL, + attributes=dashboard_node_attributes + ) + + yield dashboard_node + + # Dashboard group + if self.dashboard_group 
and not self._get_dashboard_group_key() in self._processed_dashboard_group: + self._processed_dashboard_group.add(self._get_dashboard_group_key()) + dashboard_group_node_attributes = { + DashboardMetadata.DASHBOARD_NAME: self.dashboard_group, + } + + if self.dashboard_group_url: + dashboard_group_node_attributes[DashboardMetadata.DASHBOARD_GROUP_URL] = self.dashboard_group_url + + dashboard_group_node = GraphNode( + key=self._get_dashboard_group_key(), + label=DashboardMetadata.DASHBOARD_GROUP_NODE_LABEL, + attributes=dashboard_group_node_attributes + ) + + yield dashboard_group_node + + # Dashboard group description + if self.dashboard_group_description: + dashboard_group_description_node = GraphNode( + key=self._get_dashboard_group_description_key(), + label=DashboardMetadata.DASHBOARD_DESCRIPTION_NODE_LABEL, + attributes={ + DashboardMetadata.DASHBOARD_DESCRIPTION: self.dashboard_group_description + } + ) + yield dashboard_group_description_node + + # Dashboard description node + if self.description: + dashboard_description_node = GraphNode( + key=self._get_dashboard_description_key(), + label=DashboardMetadata.DASHBOARD_DESCRIPTION_NODE_LABEL, + attributes={ + DashboardMetadata.DASHBOARD_DESCRIPTION: self.description + } + ) + yield dashboard_description_node + + # Dashboard tag node + if self.tags: + for tag in self.tags: + dashboard_tag_node = GraphNode( + key=TagMetadata.get_tag_key(tag), + label=TagMetadata.TAG_NODE_LABEL, + attributes={ + TagMetadata.TAG_TYPE: 'dashboard' + } + ) + yield dashboard_tag_node + + def create_next_relation(self) -> Union[GraphRelationship, None]: + try: + return next(self._relation_iterator) + except StopIteration: + return None + + def _create_next_relation(self) -> Iterator[GraphRelationship]: + # Cluster <-> Dashboard group + cluster_dashboard_group_relationship = GraphRelationship( + start_label=cluster_constants.CLUSTER_NODE_LABEL, + start_key=self._get_cluster_key(), + end_label=DashboardMetadata.DASHBOARD_GROUP_NODE_LABEL, + end_key=self._get_dashboard_group_key(), + type=DashboardMetadata.CLUSTER_DASHBOARD_GROUP_RELATION_TYPE, + reverse_type=DashboardMetadata.DASHBOARD_GROUP_CLUSTER_RELATION_TYPE, + attributes={} + ) + yield cluster_dashboard_group_relationship + + # Dashboard group > Dashboard group description relation + if self.dashboard_group_description: + dashboard_group_description_relationship = GraphRelationship( + start_label=DashboardMetadata.DASHBOARD_GROUP_NODE_LABEL, + start_key=self._get_dashboard_group_key(), + end_label=DashboardMetadata.DASHBOARD_DESCRIPTION_NODE_LABEL, + end_key=self._get_dashboard_group_description_key(), + type=DashboardMetadata.DASHBOARD_DESCRIPTION_RELATION_TYPE, + reverse_type=DashboardMetadata.DESCRIPTION_DASHBOARD_RELATION_TYPE, + attributes={} + ) + yield dashboard_group_description_relationship + + # Dashboard group > Dashboard relation + dashboard_group_dashboard_relationship = GraphRelationship( + start_label=DashboardMetadata.DASHBOARD_NODE_LABEL, + end_label=DashboardMetadata.DASHBOARD_GROUP_NODE_LABEL, + start_key=self._get_dashboard_key(), + end_key=self._get_dashboard_group_key(), + type=DashboardMetadata.DASHBOARD_DASHBOARD_GROUP_RELATION_TYPE, + reverse_type=DashboardMetadata.DASHBOARD_GROUP_DASHBOARD_RELATION_TYPE, + attributes={} + ) + yield dashboard_group_dashboard_relationship + + # Dashboard > Dashboard description relation + if self.description: + dashboard_description_relationship = GraphRelationship( + start_label=DashboardMetadata.DASHBOARD_NODE_LABEL, + 
end_label=DashboardMetadata.DASHBOARD_DESCRIPTION_NODE_LABEL, + start_key=self._get_dashboard_key(), + end_key=self._get_dashboard_description_key(), + type=DashboardMetadata.DASHBOARD_DESCRIPTION_RELATION_TYPE, + reverse_type=DashboardMetadata.DESCRIPTION_DASHBOARD_RELATION_TYPE, + attributes={} + ) + yield dashboard_description_relationship + + # Dashboard > Dashboard tag relation + if self.tags: + for tag in self.tags: + dashboard_tag_relationship = GraphRelationship( + start_label=DashboardMetadata.DASHBOARD_NODE_LABEL, + end_label=TagMetadata.TAG_NODE_LABEL, + start_key=self._get_dashboard_key(), + end_key=TagMetadata.get_tag_key(tag), + type=DashboardMetadata.DASHBOARD_TAG_RELATION_TYPE, + reverse_type=DashboardMetadata.TAG_DASHBOARD_RELATION_TYPE, + attributes={} + ) + yield dashboard_tag_relationship + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iterator) + except StopIteration: + return None + + def _create_record_iterator(self) -> Iterator[RDSModel]: + # Cluster + if not self._get_cluster_key() in self._processed_cluster: + self._processed_cluster.add(self._get_cluster_key()) + yield RDSDashboardCluster( + rk=self._get_cluster_key(), + name=self.cluster + ) + + # Dashboard group + if self.dashboard_group and not self._get_dashboard_group_key() in self._processed_dashboard_group: + self._processed_dashboard_group.add(self._get_dashboard_group_key()) + dashboard_group_record = RDSDashboardGroup( + rk=self._get_dashboard_group_key(), + name=self.dashboard_group, + cluster_rk=self._get_cluster_key() + ) + if self.dashboard_group_url: + dashboard_group_record.dashboard_group_url = self.dashboard_group_url + + yield dashboard_group_record + + # Dashboard group description + if self.dashboard_group_description: + yield RDSDashboardGroupDescription( + rk=self._get_dashboard_group_description_key(), + description=self.dashboard_group_description, + dashboard_group_rk=self._get_dashboard_group_key() + ) + + # Dashboard + dashboard_record = RDSDashboard( + rk=self._get_dashboard_key(), + name=self.dashboard_name, + dashboard_group_rk=self._get_dashboard_group_key() + ) + if self.created_timestamp: + dashboard_record.created_timestamp = self.created_timestamp + + if self.dashboard_url: + dashboard_record.dashboard_url = self.dashboard_url + + yield dashboard_record + + # Dashboard description + if self.description: + yield RDSDashboardDescription( + rk=self._get_dashboard_description_key(), + description=self.description, + dashboard_rk=self._get_dashboard_key() + ) + + # Dashboard tag + if self.tags: + for tag in self.tags: + tag_record = RDSTag( + rk=TagMetadata.get_tag_key(tag), + tag_type='dashboard', + ) + yield tag_record + + dashboard_tag_record = RDSDashboardTag( + dashboard_rk=self._get_dashboard_key(), + tag_rk=TagMetadata.get_tag_key(tag) + ) + yield dashboard_tag_record diff --git a/databuilder/databuilder/models/dashboard/dashboard_owner.py b/databuilder/databuilder/models/dashboard/dashboard_owner.py new file mode 100644 index 0000000000..e9de26b1fa --- /dev/null +++ b/databuilder/databuilder/models/dashboard/dashboard_owner.py @@ -0,0 +1,62 @@ +# Copyright Contributors to the Amundsen project. 
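# ---------------------------------------------------------------------------
# Illustrative usage sketch (editorial addition, not part of the diff above)
# for DashboardMetadata from dashboard_metadata.py. One instance yields the
# Cluster, Dashboardgroup, Dashboard, Description and Tag nodes plus their
# relations; all values below are placeholders.
from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata

dashboard = DashboardMetadata(
    dashboard_group='Sales',
    dashboard_name='Revenue overview',
    description='Revenue broken down by region and month',
    tags=['finance'],
    dashboard_group_id='sales',
    dashboard_id='revenue_overview',
    dashboard_group_description='Dashboards owned by the sales team',
    created_timestamp=1622520000,
    dashboard_group_url='https://mode.example.com/sales',
    dashboard_url='https://mode.example.com/sales/revenue_overview',
    product='mode',        # cluster defaults to 'gold'
)

# Dashboard node key produced by DASHBOARD_KEY_FORMAT:
#   mode_dashboard://gold.sales/revenue_overview
node = dashboard.create_next_node()
while node is not None:
    print(node)
    node = dashboard.create_next_node()

relation = dashboard.create_next_relation()
while relation is not None:
    print(relation)
    relation = dashboard.create_next_relation()
# ---------------------------------------------------------------------------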
+# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + Any, Iterator, Optional, Union, +) + +from amundsen_rds.models import RDSModel +from amundsen_rds.models.dashboard import DashboardOwner as RDSDashboardOwner + +from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata +from databuilder.models.graph_node import GraphNode +from databuilder.models.owner import Owner +from databuilder.models.user import User + + +class DashboardOwner(Owner): + """ + A model that encapsulate Dashboard's owner. + Note that it does not create new user as it has insufficient information about user but it builds relation + between User and Dashboard + """ + + def __init__(self, + dashboard_group_id: str, + dashboard_id: str, + email: str, + product: Optional[str] = '', + cluster: str = 'gold', + **kwargs: Any + ) -> None: + + Owner.__init__( + self, + start_label=DashboardMetadata.DASHBOARD_NODE_LABEL, + start_key=DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + product=product, + cluster=cluster, + dashboard_group=dashboard_group_id, + dashboard_name=dashboard_id + ), + owner_emails=[email] + ) + self._email = email + self._record_iterator = self._create_record_iterator() + + # override this because we do not want to create new User nodes from this model + def create_next_node(self) -> Union[GraphNode, None]: + return None + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iterator) + except StopIteration: + return None + + # override this because we do not want to create new User rows from this model + def _create_record_iterator(self) -> Iterator[RDSModel]: + yield RDSDashboardOwner( + user_rk=User.get_user_model_key(email=self._email), + dashboard_rk=self.start_key, + ) diff --git a/databuilder/databuilder/models/dashboard/dashboard_query.py b/databuilder/databuilder/models/dashboard/dashboard_query.py new file mode 100644 index 0000000000..0d8495b85c --- /dev/null +++ b/databuilder/databuilder/models/dashboard/dashboard_query.py @@ -0,0 +1,191 @@ +# Copyright Contributors to the Amundsen project. 
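# ---------------------------------------------------------------------------
# Illustrative usage sketch (editorial addition, not part of the diff above)
# for DashboardOwner from dashboard_owner.py. The model deliberately emits no
# User node of its own; it only relates an existing User (by email) to the
# Dashboard. The relation itself comes from the Owner base class, which is
# outside this hunk, so treat the relation shape as an assumption.
from databuilder.models.dashboard.dashboard_owner import DashboardOwner

owner = DashboardOwner(
    dashboard_group_id='sales',
    dashboard_id='revenue_overview',
    email='jdoe@example.com',
    product='mode',
)

assert owner.create_next_node() is None   # no User node is created by this model
relation = owner.create_next_relation()   # ownership relation between Dashboard and User
record = owner.create_next_record()       # RDSDashboardOwner row
# ---------------------------------------------------------------------------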
+# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import ( + Any, Iterator, Optional, Union, +) + +from amundsen_common.utils.atlas import AtlasCommonParams, AtlasDashboardTypes +from amundsen_rds.models import RDSModel +from amundsen_rds.models.dashboard import DashboardQuery as RDSDashboardQuery + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.table_serializable import TableSerializable +from databuilder.serializers.atlas_serializer import ( + add_entity_relationship, get_entity_attrs, get_entity_relationships, +) +from databuilder.utils.atlas import AtlasSerializedEntityOperation + +LOGGER = logging.getLogger(__name__) + + +class DashboardQuery(GraphSerializable, TableSerializable, AtlasSerializable): + """ + A model that encapsulate Dashboard's query name + """ + DASHBOARD_QUERY_LABEL = 'Query' + DASHBOARD_QUERY_KEY_FORMAT = '{product}_dashboard://{cluster}.{dashboard_group_id}/' \ + '{dashboard_id}/query/{query_id}' + DASHBOARD_QUERY_RELATION_TYPE = 'HAS_QUERY' + QUERY_DASHBOARD_RELATION_TYPE = 'QUERY_OF' + + def __init__(self, + dashboard_group_id: Optional[str], + dashboard_id: Optional[str], + query_name: str, + query_id: Optional[str] = None, + url: Optional[str] = '', + query_text: Optional[str] = None, + product: Optional[str] = '', + cluster: str = 'gold', + **kwargs: Any + ) -> None: + self._dashboard_group_id = dashboard_group_id + self._dashboard_id = dashboard_id + self._query_name = query_name + self._query_id = query_id if query_id else query_name + self._url = url + self._query_text = query_text + self._product = product + self._cluster = cluster + self._node_iterator = self._create_node_iterator() + self._relation_iterator = self._create_relation_iterator() + self._record_iterator = self._create_record_iterator() + self._atlas_entity_iterator = self._create_next_atlas_entity() + + def create_next_node(self) -> Union[GraphNode, None]: + try: + return next(self._node_iterator) + except StopIteration: + return None + + def _create_node_iterator(self) -> Iterator[GraphNode]: + node_attributes = { + 'id': self._query_id, + 'name': self._query_name, + } + + if self._url: + node_attributes['url'] = self._url + + if self._query_text: + node_attributes['query_text'] = self._query_text + + node = GraphNode( + key=self._get_query_node_key(), + label=DashboardQuery.DASHBOARD_QUERY_LABEL, + attributes=node_attributes + ) + + yield node + + def create_next_relation(self) -> Union[GraphRelationship, None]: + try: + return next(self._relation_iterator) + except StopIteration: + return None + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + relationship = GraphRelationship( + start_label=DashboardMetadata.DASHBOARD_NODE_LABEL, + end_label=DashboardQuery.DASHBOARD_QUERY_LABEL, + start_key=DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + product=self._product, + cluster=self._cluster, + dashboard_group=self._dashboard_group_id, + dashboard_name=self._dashboard_id + ), + end_key=self._get_query_node_key(), + type=DashboardQuery.DASHBOARD_QUERY_RELATION_TYPE, + 
reverse_type=DashboardQuery.QUERY_DASHBOARD_RELATION_TYPE, + attributes={} + ) + yield relationship + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iterator) + except StopIteration: + return None + + def _create_record_iterator(self) -> Iterator[RDSModel]: + record = RDSDashboardQuery( + rk=self._get_query_node_key(), + id=self._query_id, + name=self._query_name, + dashboard_rk=DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + product=self._product, + cluster=self._cluster, + dashboard_group=self._dashboard_group_id, + dashboard_name=self._dashboard_id + ) + ) + if self._url: + record.url = self._url + if self._query_text: + record.query_text = self._query_text + + yield record + + def _get_query_node_key(self) -> str: + return DashboardQuery.DASHBOARD_QUERY_KEY_FORMAT.format( + product=self._product, + cluster=self._cluster, + dashboard_group_id=self._dashboard_group_id, + dashboard_id=self._dashboard_id, + query_id=self._query_id + ) + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + try: + return next(self._atlas_entity_iterator) + except StopIteration: + return None + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + return None + + def _create_next_atlas_entity(self) -> Iterator[AtlasEntity]: + # Query + attrs_mapping = [ + (AtlasCommonParams.qualified_name, self._get_query_node_key()), + ('name', self._query_name), + ('id', self._query_id), + ('url', self._url), + ('queryText', self._query_text) + ] + + query_entity_attrs = get_entity_attrs(attrs_mapping) + + relationship_list = [] # type: ignore + + add_entity_relationship( + relationship_list, + 'dashboard', + AtlasDashboardTypes.metadata, + DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + product=self._product, + cluster=self._cluster, + dashboard_group=self._dashboard_group_id, + dashboard_name=self._dashboard_id + ) + ) + + query_entity = AtlasEntity( + typeName=AtlasDashboardTypes.query, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=query_entity_attrs, + relationships=get_entity_relationships(relationship_list) + ) + yield query_entity + + def __repr__(self) -> str: + return f'DashboardQuery({self._dashboard_group_id!r}, {self._dashboard_id!r}, {self._query_name!r}, ' \ + f'{self._query_id!r}, {self._url!r}, {self._query_text!r}, {self._product!r}, {self._cluster!r})' diff --git a/databuilder/databuilder/models/dashboard/dashboard_table.py b/databuilder/databuilder/models/dashboard/dashboard_table.py new file mode 100644 index 0000000000..bccb1a0c11 --- /dev/null +++ b/databuilder/databuilder/models/dashboard/dashboard_table.py @@ -0,0 +1,150 @@ +# Copyright Contributors to the Amundsen project. 
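# ---------------------------------------------------------------------------
# Illustrative usage sketch (editorial addition, not part of the diff above)
# for DashboardQuery from dashboard_query.py; values are placeholders.
from databuilder.models.dashboard.dashboard_query import DashboardQuery

query = DashboardQuery(
    dashboard_group_id='sales',
    dashboard_id='revenue_overview',
    query_name='monthly revenue',
    query_id='5',
    url='https://mode.example.com/queries/5',
    query_text='SELECT region, SUM(amount) FROM revenue GROUP BY region',
    product='mode',
)

# Query node key produced by DASHBOARD_QUERY_KEY_FORMAT:
#   mode_dashboard://gold.sales/revenue_overview/query/5
node = query.create_next_node()
relation = query.create_next_relation()   # Dashboard -[HAS_QUERY]-> Query
# ---------------------------------------------------------------------------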
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import re +from typing import ( + Any, Iterator, List, Optional, Union, +) + +from amundsen_common.utils.atlas import ( + AtlasDashboardTypes, AtlasTableKey, AtlasTableTypes, +) +from amundsen_rds.models import RDSModel +from amundsen_rds.models.dashboard import DashboardTable as RDSDashboardTable + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.table_metadata import TableMetadata +from databuilder.models.table_serializable import TableSerializable +from databuilder.utils.atlas import AtlasRelationshipTypes + +LOGGER = logging.getLogger(__name__) + + +class DashboardTable(GraphSerializable, TableSerializable, AtlasSerializable): + """ + A model that link Dashboard with the tables used in various charts of the dashboard. + Note that it does not create new dashboard, table as it has insufficient information but it builds relation + between Tables and Dashboard + """ + + DASHBOARD_TABLE_RELATION_TYPE = 'DASHBOARD_WITH_TABLE' + TABLE_DASHBOARD_RELATION_TYPE = 'TABLE_OF_DASHBOARD' + + def __init__(self, + dashboard_group_id: str, + dashboard_id: str, + table_ids: List[str], + product: Optional[str] = '', + cluster: str = 'gold', + **kwargs: Any + ) -> None: + self._dashboard_group_id = dashboard_group_id + self._dashboard_id = dashboard_id + # A list of tables uri used in the dashboard + self._table_ids = table_ids + self._product = product + self._cluster = cluster + + self._relation_iterator = self._create_relation_iterator() + self._record_iterator = self._create_record_iterator() + self._atlas_relation_iterator = self._create_atlas_relation_iterator() + + def create_next_node(self) -> Union[GraphNode, None]: + return None + + def create_next_relation(self) -> Union[GraphRelationship, None]: + if self._relation_iterator is None: + return None + try: + return next(self._relation_iterator) + except StopIteration: + return None + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + for table_id in self._table_ids: + m = re.match(r'([^./]+)://([^./]+)\.([^./]+)\/([^./]+)', table_id) + if m: + relationship = GraphRelationship( + start_label=DashboardMetadata.DASHBOARD_NODE_LABEL, + end_label=TableMetadata.TABLE_NODE_LABEL, + start_key=DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + product=self._product, + cluster=self._cluster, + dashboard_group=self._dashboard_group_id, + dashboard_name=self._dashboard_id + ), + end_key=TableMetadata.TABLE_KEY_FORMAT.format( + db=m.group(1), + cluster=m.group(2), + schema=m.group(3), + tbl=m.group(4) + ), + type=DashboardTable.DASHBOARD_TABLE_RELATION_TYPE, + reverse_type=DashboardTable.TABLE_DASHBOARD_RELATION_TYPE, + attributes={} + ) + yield relationship + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iterator) + except StopIteration: + return None + + def _create_record_iterator(self) -> Iterator[RDSModel]: + for table_id in self._table_ids: + m = re.match(r'([^./]+)://([^./]+)\.([^./]+)\/([^./]+)', table_id) + if m: + yield RDSDashboardTable( + 
dashboard_rk=DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + product=self._product, + cluster=self._cluster, + dashboard_group=self._dashboard_group_id, + dashboard_name=self._dashboard_id + ), + table_rk=TableMetadata.TABLE_KEY_FORMAT.format( + db=m.group(1), + cluster=m.group(2), + schema=m.group(3), + tbl=m.group(4) + ) + ) + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + pass + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + try: + return next(self._atlas_relation_iterator) + except StopIteration: + return None + + def _create_atlas_relation_iterator(self) -> Iterator[AtlasRelationship]: + for table_id in self._table_ids: + key = AtlasTableKey(table_id) + + table_relationship = AtlasRelationship( + relationshipType=AtlasRelationshipTypes.table_dashboard, + entityType1=AtlasTableTypes.table, + entityQualifiedName1=key.qualified_name, + entityType2=AtlasDashboardTypes.metadata, + entityQualifiedName2=DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + product=self._product, + cluster=self._cluster, + dashboard_group=self._dashboard_group_id, + dashboard_name=self._dashboard_id + ), + attributes={} + ) + yield table_relationship + + def __repr__(self) -> str: + return f'DashboardTable({self._dashboard_group_id!r}, {self._dashboard_id!r}, ' \ + f'{self._product!r}, {self._cluster!r}, ({",".join(self._table_ids)!r}))' diff --git a/databuilder/databuilder/models/dashboard/dashboard_usage.py b/databuilder/databuilder/models/dashboard/dashboard_usage.py new file mode 100644 index 0000000000..eed56896a0 --- /dev/null +++ b/databuilder/databuilder/models/dashboard/dashboard_usage.py @@ -0,0 +1,69 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + Any, Optional, Union, +) + +from amundsen_rds.models import RDSModel +from amundsen_rds.models.user import User as RDSUser + +from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata +from databuilder.models.graph_node import GraphNode +from databuilder.models.usage.usage import Usage + + +class DashboardUsage(Usage): + """ + A model that encapsulate Dashboard usage between Dashboard and User + """ + + def __init__(self, + dashboard_group_id: Optional[str], + dashboard_id: Optional[str], + email: str, + view_count: int, + should_create_user_node: Optional[bool] = False, + product: Optional[str] = '', + cluster: Optional[str] = 'gold', + **kwargs: Any + ) -> None: + """ + :param dashboard_group_id: + :param dashboard_id: + :param email: + :param view_count: + :param should_create_user_node: Enable this if it is fine to create/update User node with only with email + address. Please be advised that other fields will be emptied. Current use case is to create anonymous user. + For example, Mode dashboard does not provide which user viewed the dashboard and anonymous user can be used + to show the usage. 
+ :param product: + :param cluster: + :param kwargs: + """ + self._should_create_user_node = bool(should_create_user_node) + Usage.__init__( + self, + start_label=DashboardMetadata.DASHBOARD_NODE_LABEL, + start_key=DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + product=product, + cluster=cluster, + dashboard_group=dashboard_group_id, + dashboard_name=dashboard_id + ), + user_email=email, + read_count=view_count, + ) + + # override superclass for customized _should_create_user_node behavior + def create_next_node(self) -> Union[GraphNode, None]: + if self._should_create_user_node: + return super().create_next_node() + return None + + # override superclass for customized _should_create_user_node behavior + def create_next_record(self) -> Union[RDSModel, None]: + rec = super().create_next_record() + if isinstance(rec, RDSUser) and not self._should_create_user_node: + rec = super().create_next_record() + return rec diff --git a/databuilder/databuilder/models/dashboard_elasticsearch_document.py b/databuilder/databuilder/models/dashboard_elasticsearch_document.py new file mode 100644 index 0000000000..7fe97a27de --- /dev/null +++ b/databuilder/databuilder/models/dashboard_elasticsearch_document.py @@ -0,0 +1,47 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + List, Optional, Union, +) + +from databuilder.models.elasticsearch_document import ElasticsearchDocument + + +class DashboardESDocument(ElasticsearchDocument): + """ + Schema for the ES dashboard ES document + """ + + def __init__(self, + group_name: str, + name: str, + description: Union[str, None], + total_usage: int, + product: Optional[str] = '', + cluster: Optional[str] = '', + group_description: Optional[str] = None, + query_names: Union[List[str], None] = None, + chart_names: Optional[List[str]] = None, + group_url: Optional[str] = None, + url: Optional[str] = None, + uri: Optional[str] = None, + last_successful_run_timestamp: Optional[int] = None, + tags: Optional[List[str]] = None, + badges: Optional[List[str]] = None, + ) -> None: + self.group_name = group_name + self.name = name + self.description = description + self.cluster = cluster + self.product = product + self.group_url = group_url + self.url = url + self.uri = uri + self.last_successful_run_timestamp = last_successful_run_timestamp + self.total_usage = total_usage + self.group_description = group_description + self.query_names = query_names + self.chart_names = chart_names + self.tags = tags + self.badges = badges diff --git a/databuilder/databuilder/models/description_metadata.py b/databuilder/databuilder/models/description_metadata.py new file mode 100644 index 0000000000..edf5bbafae --- /dev/null +++ b/databuilder/databuilder/models/description_metadata.py @@ -0,0 +1,158 @@ +# Copyright Contributors to the Amundsen project. 
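# ---------------------------------------------------------------------------
# Illustrative usage sketch (editorial addition, not part of the diff above)
# covering the two relation-only models added above, DashboardTable and
# DashboardUsage. Table ids must match the '{db}://{cluster}.{schema}/{table}'
# pattern parsed by DashboardTable; all concrete values are placeholders, and
# the usage relation itself comes from the Usage base class outside this hunk.
from databuilder.models.dashboard.dashboard_table import DashboardTable
from databuilder.models.dashboard.dashboard_usage import DashboardUsage

tables = DashboardTable(
    dashboard_group_id='sales',
    dashboard_id='revenue_overview',
    table_ids=['hive://gold.finance/revenue'],
    product='mode',
)
table_relation = tables.create_next_relation()   # Dashboard -[DASHBOARD_WITH_TABLE]-> Table

usage = DashboardUsage(
    dashboard_group_id='sales',
    dashboard_id='revenue_overview',
    email='jdoe@example.com',
    view_count=42,
    should_create_user_node=False,   # leave any existing User node untouched
    product='mode',
)
assert usage.create_next_node() is None          # no User node unless opted in
usage_relation = usage.create_next_relation()    # usage relation carrying the view count
# ---------------------------------------------------------------------------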
+# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + Any, Iterator, Optional, Union, +) + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable + +DESCRIPTION_NODE_LABEL_VAL = 'Description' +DESCRIPTION_NODE_LABEL = DESCRIPTION_NODE_LABEL_VAL + + +class DescriptionMetadata(GraphSerializable, AtlasSerializable): + DESCRIPTION_NODE_LABEL = DESCRIPTION_NODE_LABEL_VAL + PROGRAMMATIC_DESCRIPTION_NODE_LABEL = 'Programmatic_Description' + DESCRIPTION_KEY_FORMAT = '{description}' + DESCRIPTION_TEXT = 'description' + DESCRIPTION_SOURCE = 'description_source' + + DESCRIPTION_RELATION_TYPE = 'DESCRIPTION' + INVERSE_DESCRIPTION_RELATION_TYPE = 'DESCRIPTION_OF' + + # The default editable source. + DEFAULT_SOURCE = "description" + + def __init__(self, + text: Optional[str], + source: str = DEFAULT_SOURCE, + description_key: Optional[str] = None, + start_label: Optional[str] = None, # Table, Column, Schema, Type_Metadata + start_key: Optional[str] = None, + ): + """ + :param source: The unique source of what is populating this description. + :param text: the description text. Markdown supported. + """ + self.source = source + self.text = text + # There are so many dependencies on Description node, that it is probably easier to just separate the rest out. + if self.source == self.DEFAULT_SOURCE: + self.label = self.DESCRIPTION_NODE_LABEL + else: + self.label = self.PROGRAMMATIC_DESCRIPTION_NODE_LABEL + + self.start_label = start_label + self.start_key = start_key + self.description_key = description_key or self.get_description_default_key(start_key) + + self._node_iter = self._create_node_iterator() + self._relation_iter = self._create_relation_iterator() + + def __eq__(self, other: Any) -> bool: + if isinstance(other, DescriptionMetadata): + return (self.text == other.text and + self.source == other.source and + self.description_key == other.description_key and + self.start_label == other.start_label and + self.start_key == self.start_key) + return False + + @staticmethod + def create_description_metadata(text: Union[None, str], + source: Optional[str] = DEFAULT_SOURCE, + description_key: Optional[str] = None, + start_label: Optional[str] = None, # Table, Column, Schema + start_key: Optional[str] = None, + ) -> Optional['DescriptionMetadata']: + # We do not want to create a node if there is no description text! 
+ if text is None: + return None + description_node = DescriptionMetadata(text=text, + source=source or DescriptionMetadata.DEFAULT_SOURCE, + description_key=description_key, + start_label=start_label, + start_key=start_key) + return description_node + + def get_description_id(self) -> str: + if self.source == self.DEFAULT_SOURCE: + return "_description" + else: + return "_" + self.source + "_description" + + def get_description_default_key(self, start_key: Optional[str]) -> Optional[str]: + return f'{start_key}/{self.get_description_id()}' if start_key else None + + def get_node(self, node_key: str) -> GraphNode: + node = GraphNode( + key=node_key, + label=self.label, + attributes={ + DescriptionMetadata.DESCRIPTION_SOURCE: self.source, + DescriptionMetadata.DESCRIPTION_TEXT: self.text + } + ) + return node + + def get_relation(self, + start_node: str, + start_key: str, + end_key: str, + ) -> GraphRelationship: + relationship = GraphRelationship( + start_label=start_node, + start_key=start_key, + end_label=self.label, + end_key=end_key, + type=DescriptionMetadata.DESCRIPTION_RELATION_TYPE, + reverse_type=DescriptionMetadata.INVERSE_DESCRIPTION_RELATION_TYPE, + attributes={} + ) + return relationship + + def create_next_node(self) -> Optional[GraphNode]: + # return the string representation of the data + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Optional[GraphRelationship]: + try: + return next(self._relation_iter) + except StopIteration: + return None + + def _create_node_iterator(self) -> Iterator[GraphNode]: + if not self.description_key: + raise Exception('Required description node key cannot be None') + yield self.get_node(self.description_key) + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + if not self.start_label: + raise Exception('Required relation start node label cannot be None') + if not self.start_key: + raise Exception('Required relation start key cannot be None') + if not self.description_key: + raise Exception('Required relation end key cannot be None') + yield self.get_relation( + start_node=self.start_label, + start_key=self.start_key, + end_key=self.description_key + ) + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + pass + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + pass + + def __repr__(self) -> str: + return f'DescriptionMetadata({self.source!r}, {self.text!r})' diff --git a/databuilder/databuilder/models/elasticsearch_document.py b/databuilder/databuilder/models/elasticsearch_document.py new file mode 100644 index 0000000000..903790835e --- /dev/null +++ b/databuilder/databuilder/models/elasticsearch_document.py @@ -0,0 +1,22 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import json +from abc import ABCMeta + + +class ElasticsearchDocument: + """ + Base class for ElasticsearchDocument + Each different resource ESDoc will be a subclass + """ + __metaclass__ = ABCMeta + + def to_json(self) -> str: + """ + Convert object to json + :return: + """ + obj_dict = {k: v for k, v in sorted(self.__dict__.items())} + data = json.dumps(obj_dict) + "\n" + return data diff --git a/databuilder/databuilder/models/es_last_updated.py b/databuilder/databuilder/models/es_last_updated.py new file mode 100644 index 0000000000..bacf1b927c --- /dev/null +++ b/databuilder/databuilder/models/es_last_updated.py @@ -0,0 +1,77 @@ +# Copyright Contributors to the Amundsen project. 
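# ---------------------------------------------------------------------------
# Illustrative usage sketch (editorial addition, not part of the diff above)
# for DescriptionMetadata from description_metadata.py. The factory returns
# None when there is no text, so callers can simply skip emitting anything;
# the table key below is a placeholder.
from databuilder.models.description_metadata import DescriptionMetadata

description = DescriptionMetadata.create_description_metadata(
    text='Fact table with one row per completed order',
    start_label='Table',
    start_key='hive://gold.finance/orders',
)
if description is not None:
    node = description.create_next_node()          # key: hive://gold.finance/orders/_description
    relation = description.create_next_relation()  # Table -[DESCRIPTION]-> Description

# A non-default source is emitted as a Programmatic_Description node instead:
programmatic = DescriptionMetadata.create_description_metadata(
    text='Generated by the warehouse DAG',
    source='airflow',
    start_label='Table',
    start_key='hive://gold.finance/orders',
)
# ---------------------------------------------------------------------------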
+# SPDX-License-Identifier: Apache-2.0 + +from typing import Iterator, Union + +from amundsen_rds.models import RDSModel +from amundsen_rds.models.updated_timestamp import UpdatedTimestamp as RDSUpdatedTimestamp + +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.table_serializable import TableSerializable + + +class ESLastUpdated(GraphSerializable, TableSerializable): + """ + Data model to keep track the last updated timestamp for + datastore and es. + """ + + LABEL = 'Updatedtimestamp' + KEY = 'amundsen_updated_timestamp' + LATEST_TIMESTAMP = 'latest_timestamp' + + def __init__(self, + timestamp: int, + ) -> None: + """ + :param timestamp: epoch for latest updated timestamp for neo4j an es + """ + self.timestamp = timestamp + self._node_iter = self._create_node_iterator() + self._rel_iter = self._create_relation_iterator() + self._record_iter = self._create_record_iterator() + + def create_next_node(self) -> Union[GraphNode, None]: + """ + Will create an orphan node for last updated timestamp. + """ + try: + return next(self._node_iter) + except StopIteration: + return None + + def _create_node_iterator(self) -> Iterator[GraphNode]: + """ + Create an es_updated_timestamp node + """ + node = GraphNode( + key=ESLastUpdated.KEY, + label=ESLastUpdated.LABEL, + attributes={ + ESLastUpdated.LATEST_TIMESTAMP: self.timestamp + } + ) + yield node + + def create_next_relation(self) -> Union[GraphRelationship, None]: + try: + return next(self._rel_iter) + except StopIteration: + return None + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + return + yield + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iter) + except StopIteration: + return None + + def _create_record_iterator(self) -> Iterator[RDSModel]: + record = RDSUpdatedTimestamp(rk=ESLastUpdated.KEY, + latest_timestamp=self.timestamp) + yield record diff --git a/databuilder/databuilder/models/feature/__init__.py b/databuilder/databuilder/models/feature/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/models/feature/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/models/feature/feature_elasticsearch_document.py b/databuilder/databuilder/models/feature/feature_elasticsearch_document.py new file mode 100644 index 0000000000..5932c32b80 --- /dev/null +++ b/databuilder/databuilder/models/feature/feature_elasticsearch_document.py @@ -0,0 +1,39 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional + +from databuilder.models.elasticsearch_document import ElasticsearchDocument + + +class FeatureESDocument(ElasticsearchDocument): + """ + Schema for the Feature ES document + """ + + def __init__(self, + feature_group: str, + feature_name: str, + version: str, + key: str, + total_usage: int, + status: Optional[str] = None, + entity: Optional[str] = None, + description: Optional[str] = None, + availability: Optional[List[str]] = None, + badges: Optional[List[str]] = None, + tags: Optional[List[str]] = None, + last_updated_timestamp: Optional[int] = None, + ) -> None: + self.feature_group = feature_group + self.feature_name = feature_name + self.version = version + self.key = key + self.total_usage = total_usage + self.status = status + self.entity = entity + self.description = description + self.availability = availability + self.badges = badges + self.tags = tags + self.last_updated_timestamp = last_updated_timestamp diff --git a/databuilder/databuilder/models/feature/feature_generation_code.py b/databuilder/databuilder/models/feature/feature_generation_code.py new file mode 100644 index 0000000000..a1cb4d59e2 --- /dev/null +++ b/databuilder/databuilder/models/feature/feature_generation_code.py @@ -0,0 +1,97 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + Any, Dict, Iterator, Optional, Union, +) + +from databuilder.models.feature.feature_metadata import FeatureMetadata +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable + + +# FeatureGenerationCode allows ingesting as text the generation code - whether sql or not - +# which was used to create a feature. Unlike the Query node for Dashboards, it has no inherent +# concept of name, url, id, or hierarchical structure. This allows for maximum flexibility to +# ingest generation code regardless of source. 
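+#
+# A minimal usage sketch with hypothetical values (the class is defined just below):
+#
+#     gen_code = FeatureGenerationCode(feature_group='fares',
+#                                      feature_name='max_fare',
+#                                      feature_version='1',
+#                                      text='SELECT MAX(fare) FROM fares GROUP BY listing_id',
+#                                      source='nightly_airflow_dag')
+#
+# Serializing it yields one Feature_Generation_Code node keyed '<feature key>/_generation_code'
+# and a GENERATION_CODE / GENERATION_CODE_OF relation pair back to the Feature node.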
+class FeatureGenerationCode(GraphSerializable): + NODE_LABEL = 'Feature_Generation_Code' + + TEXT_ATTR = 'text' + LAST_EXECUTED_TIMESTAMP_ATTR = 'last_executed_timestamp' + SOURCE_ATTR = 'source' + + FEATURE_GENCODE_RELATION_TYPE = 'GENERATION_CODE' + GENCODE_FEATURE_RELATION_TYPE = 'GENERATION_CODE_OF' + + def __init__(self, + feature_group: str, + feature_name: str, + feature_version: str, + text: str, + source: Optional[str] = None, + last_executed_timestamp: Optional[int] = None, + **kwargs: Any + ) -> None: + + self.feature_group = feature_group + self.feature_name = feature_name + self.feature_version = feature_version + self.text = text + self.source = source + self.last_executed_timestamp = last_executed_timestamp + + self._node_iterator = self._create_node_iterator() + self._relation_iterator = self._create_relation_iterator() + + def __repr__(self) -> str: + return f'Feature_Generation_Code({self.feature_group!r}, {self.feature_name!r}, {self.feature_version!r}, ' \ + f'{self.text!r}, {self.source!r}, {self.last_executed_timestamp!r})' + + def _get_feature_key(self) -> str: + return FeatureMetadata.KEY_FORMAT.format(feature_group=self.feature_group, + name=self.feature_name, + version=self.feature_version) + + def _get_generation_code_key(self) -> str: + return f'{self._get_feature_key()}/_generation_code' + + def create_next_node(self) -> Union[GraphNode, None]: + try: + return next(self._node_iterator) + except StopIteration: + return None + + def _create_node_iterator(self) -> Iterator[GraphNode]: + attrs: Dict[str, Any] = { + FeatureGenerationCode.TEXT_ATTR: self.text, + } + if self.last_executed_timestamp: + attrs[FeatureGenerationCode.LAST_EXECUTED_TIMESTAMP_ATTR] = self.last_executed_timestamp + + if self.source: + attrs[FeatureGenerationCode.SOURCE_ATTR] = self.source + + yield GraphNode( + key=self._get_generation_code_key(), + label=FeatureGenerationCode.NODE_LABEL, + attributes=attrs, + ) + + def create_next_relation(self) -> Union[GraphRelationship, None]: + try: + return next(self._relation_iterator) + except StopIteration: + return None + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + yield GraphRelationship( + start_label=FeatureMetadata.NODE_LABEL, + end_label=FeatureGenerationCode.NODE_LABEL, + start_key=self._get_feature_key(), + end_key=self._get_generation_code_key(), + type=FeatureGenerationCode.FEATURE_GENCODE_RELATION_TYPE, + reverse_type=FeatureGenerationCode.GENCODE_FEATURE_RELATION_TYPE, + attributes={}, + ) diff --git a/databuilder/databuilder/models/feature/feature_metadata.py b/databuilder/databuilder/models/feature/feature_metadata.py new file mode 100644 index 0000000000..80c59ebc26 --- /dev/null +++ b/databuilder/databuilder/models/feature/feature_metadata.py @@ -0,0 +1,214 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + Any, Dict, Iterator, List, Optional, Set, +) + +from databuilder.models.description_metadata import DescriptionMetadata +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.table_metadata import ( + TableMetadata, TagMetadata, _format_as_list, +) + + +class FeatureMetadata(GraphSerializable): + """ + Base feature metadata. 
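+    Each Feature node is keyed as '{feature_group}/{name}/{version}' (see KEY_FORMAT below).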
+ + It implements GraphSerializable (TODO: implement TableSerializable) + so that it can be serialized to produce Feature, Feature_Group, Tag, + Database, Description and relations between those. + """ + + NODE_LABEL = 'Feature' + KEY_FORMAT = '{feature_group}/{name}/{version}' + + NAME_ATTR = 'name' + VERSION_ATTR = 'version' + STATUS_ATTR = 'status' + ENTITY_ATTR = 'entity' + DATA_TYPE_ATTR = 'data_type' + CREATED_TIMESTAMP_ATTR = 'created_timestamp' + LAST_UPDATED_TIMESTAMP_ATTR = 'last_updated_timestamp' + + GROUP_NODE_LABEL = 'Feature_Group' + GROUP_KEY_FORMAT = '{feature_group}' + GROUP_FEATURE_RELATION_TYPE = 'GROUPS' + FEATURE_GROUP_RELATION_TYPE = 'GROUPED_BY' + + FEATURE_DATABASE_RELATION_TYPE = 'FEATURE_AVAILABLE_IN' + DATABASE_FEATURE_RELATION_TYPE = 'AVAILABLE_FEATURE' + + processed_feature_group_keys: Set[str] = set() + processed_database_keys: Set[str] = set() + + def __init__(self, + feature_group: str, + name: str, + version: str, + status: Optional[str] = None, + entity: Optional[str] = None, + data_type: Optional[str] = None, + availability: Optional[List[str]] = None, # list of databases + description: Optional[str] = None, + tags: Optional[List[str]] = None, + created_timestamp: Optional[int] = None, + last_updated_timestamp: Optional[int] = None, + **kwargs: Any + ) -> None: + + self.feature_group = feature_group + self.name = name + self.version = version + self.status = status + # what business entity the feature is about, e.g. 'Buyer', 'Ride', 'Listing', etc. + self.entity = entity + self.data_type = data_type + self.availability = _format_as_list(availability) + self.description = DescriptionMetadata.create_description_metadata(text=description) + self.tags = _format_as_list(tags) + self.created_timestamp = created_timestamp + self.last_updated_timestamp = last_updated_timestamp + + self._node_iterator = self._create_next_node() + self._relation_iterator = self._create_next_relation() + + def __repr__(self) -> str: + return f'FeatureMetadata(' \ + f'{self.feature_group!r}, {self.name!r}, {self.version!r}, {self.status!r}, ' \ + f'{self.entity!r}, {self.data_type!r}, {self.availability!r}, {self.description!r}, ' \ + f'{self.tags!r}, {self.created_timestamp!r}, {self.last_updated_timestamp!r})' + + def _get_feature_key(self) -> str: + return FeatureMetadata.KEY_FORMAT.format(feature_group=self.feature_group, + name=self.name, + version=self.version) + + def _get_feature_group_key(self) -> str: + return FeatureMetadata.GROUP_KEY_FORMAT.format(feature_group=self.feature_group) + + def create_next_node(self) -> Optional[GraphNode]: + try: + return next(self._node_iterator) + except StopIteration: + return None + + def _get_feature_node_attributes(self) -> Dict[str, Any]: + feature_node_attrs: Dict[str, Any] = { + FeatureMetadata.NAME_ATTR: self.name, + } + if self.version: + feature_node_attrs[FeatureMetadata.VERSION_ATTR] = self.version + + if self.status: + feature_node_attrs[FeatureMetadata.STATUS_ATTR] = self.status + + if self.entity: + feature_node_attrs[FeatureMetadata.ENTITY_ATTR] = self.entity + + if self.data_type: + feature_node_attrs[FeatureMetadata.DATA_TYPE_ATTR] = self.data_type + + if self.created_timestamp: + feature_node_attrs[FeatureMetadata.CREATED_TIMESTAMP_ATTR] = self.created_timestamp + + if self.last_updated_timestamp: + feature_node_attrs[FeatureMetadata.LAST_UPDATED_TIMESTAMP_ATTR] = self.last_updated_timestamp + + return feature_node_attrs + + def _create_next_node(self) -> Iterator[GraphNode]: + yield GraphNode( + 
key=self._get_feature_key(), + label=FeatureMetadata.NODE_LABEL, + attributes=self._get_feature_node_attributes() + ) + + if self.feature_group: + fg = GraphNode( + key=self._get_feature_group_key(), + label=FeatureMetadata.GROUP_NODE_LABEL, + attributes={FeatureMetadata.NAME_ATTR: self.feature_group}, + ) + if fg.key not in FeatureMetadata.processed_feature_group_keys: + FeatureMetadata.processed_feature_group_keys.add(fg.key) + yield fg + + if self.description: + yield self.description.get_node( + node_key=self.description.get_description_default_key( # type: ignore + start_key=self._get_feature_key()), + ) + + for database in self.availability: + db = GraphNode( + key=TableMetadata.DATABASE_KEY_FORMAT.format(db=database), + label=TableMetadata.DATABASE_NODE_LABEL, + attributes={ + FeatureMetadata.NAME_ATTR: database + } + ) + if db.key not in FeatureMetadata.processed_database_keys: + FeatureMetadata.processed_database_keys.add(db.key) + yield db + + for tag_value in self.tags: + yield TagMetadata(name=tag_value).get_node() + + def create_next_relation(self) -> Optional[GraphRelationship]: + try: + return next(self._relation_iterator) + except StopIteration: + return None + + def _create_next_relation(self) -> Iterator[GraphRelationship]: + # Feature <> Feature group + if self.feature_group: + yield GraphRelationship( + start_label=FeatureMetadata.NODE_LABEL, + end_label=FeatureMetadata.GROUP_NODE_LABEL, + start_key=self._get_feature_key(), + end_key=self._get_feature_group_key(), + type=FeatureMetadata.FEATURE_GROUP_RELATION_TYPE, + reverse_type=FeatureMetadata.GROUP_FEATURE_RELATION_TYPE, + attributes={} + ) + + # Feature <> Description + if self.description: + yield GraphRelationship( + start_label=FeatureMetadata.NODE_LABEL, + end_label=DescriptionMetadata.DESCRIPTION_NODE_LABEL, + start_key=self._get_feature_key(), + end_key=self.description.get_description_default_key(start_key=self._get_feature_key()), + type=DescriptionMetadata.DESCRIPTION_RELATION_TYPE, + reverse_type=DescriptionMetadata.INVERSE_DESCRIPTION_RELATION_TYPE, + attributes={} + ) + + # Feature <> Database + for database in self.availability: + yield GraphRelationship( + start_label=FeatureMetadata.NODE_LABEL, + end_label=TableMetadata.DATABASE_NODE_LABEL, + start_key=self._get_feature_key(), + end_key=TableMetadata.DATABASE_KEY_FORMAT.format(db=database), + type=FeatureMetadata.FEATURE_DATABASE_RELATION_TYPE, + reverse_type=FeatureMetadata.DATABASE_FEATURE_RELATION_TYPE, + attributes={} + ) + + # Feature <> Tag + for tag in self.tags: + yield GraphRelationship( + start_label=FeatureMetadata.NODE_LABEL, + end_label=TagMetadata.TAG_NODE_LABEL, + start_key=self._get_feature_key(), + end_key=TagMetadata.get_tag_key(tag), + type=TagMetadata.ENTITY_TAG_RELATION_TYPE, + reverse_type=TagMetadata.TAG_ENTITY_RELATION_TYPE, + attributes={} + ) diff --git a/databuilder/databuilder/models/feature/feature_watermark.py b/databuilder/databuilder/models/feature/feature_watermark.py new file mode 100644 index 0000000000..25c7c3393c --- /dev/null +++ b/databuilder/databuilder/models/feature/feature_watermark.py @@ -0,0 +1,85 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from typing import Iterator, Union + +from databuilder.models.feature.feature_metadata import FeatureMetadata +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.timestamp.timestamp_constants import TIMESTAMP_PROPERTY + + +# Unlike Watermark, which concerns itself with table implementation specifics (like partition), +# Feature_Watermark is more general and does not care how the feature is stored. +class FeatureWatermark(GraphSerializable): + """ + Represents the high and low timestamp of data in a Feature. + """ + NODE_LABEL = 'Feature_Watermark' + + TYPE_ATTR = 'watermark_type' + + WATERMARK_FEATURE_RELATION = 'BELONG_TO_FEATURE' + FEATURE_WATERMARK_RELATION = 'WATERMARK' + + def __init__(self, + feature_group: str, + feature_name: str, + feature_version: str, + timestamp: int, + wm_type: str = 'high_watermark', + ) -> None: + self.feature_group = feature_group + self.feature_name = feature_name + self.feature_version = feature_version + self.timestamp = timestamp + self.wm_type = wm_type + + self._node_iter = self._create_node_iterator() + self._relation_iter = self._create_relation_iterator() + + def __repr__(self) -> str: + return f'Feature_Watermark({self.wm_type!r}, {self.timestamp!r}, {self.feature_group!r}, ' \ + f'{self.feature_name!r}, {self.feature_version!r})' + + def _get_feature_key(self) -> str: + return FeatureMetadata.KEY_FORMAT.format(feature_group=self.feature_group, + name=self.feature_name, + version=self.feature_version) + + def _get_watermark_key(self) -> str: + return f'{self._get_feature_key()}/{self.wm_type}' + + def create_next_node(self) -> Union[GraphNode, None]: + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Union[GraphRelationship, None]: + try: + return next(self._relation_iter) + except StopIteration: + return None + + def _create_node_iterator(self) -> Iterator[GraphNode]: + yield GraphNode( + key=self._get_watermark_key(), + label=FeatureWatermark.NODE_LABEL, + attributes={ + TIMESTAMP_PROPERTY: self.timestamp, + FeatureWatermark.TYPE_ATTR: self.wm_type, + } + ) + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + yield GraphRelationship( + start_key=self._get_feature_key(), + start_label=FeatureMetadata.NODE_LABEL, + end_key=self._get_watermark_key(), + end_label=FeatureWatermark.NODE_LABEL, + type=FeatureWatermark.FEATURE_WATERMARK_RELATION, + reverse_type=FeatureWatermark.WATERMARK_FEATURE_RELATION, + attributes={} + ) diff --git a/databuilder/databuilder/models/graph_node.py b/databuilder/databuilder/models/graph_node.py new file mode 100644 index 0000000000..52c9ed6f68 --- /dev/null +++ b/databuilder/databuilder/models/graph_node.py @@ -0,0 +1,13 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from collections import namedtuple + +GraphNode = namedtuple( + 'GraphNode', + [ + 'key', + 'label', + 'attributes' + ] +) diff --git a/databuilder/databuilder/models/graph_relationship.py b/databuilder/databuilder/models/graph_relationship.py new file mode 100644 index 0000000000..868c963a27 --- /dev/null +++ b/databuilder/databuilder/models/graph_relationship.py @@ -0,0 +1,17 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from collections import namedtuple + +GraphRelationship = namedtuple( + 'GraphRelationship', + [ + 'start_label', + 'end_label', + 'start_key', + 'end_key', + 'type', + 'reverse_type', + 'attributes' + ] +) diff --git a/databuilder/databuilder/models/graph_serializable.py b/databuilder/databuilder/models/graph_serializable.py new file mode 100644 index 0000000000..8319e03bfe --- /dev/null +++ b/databuilder/databuilder/models/graph_serializable.py @@ -0,0 +1,91 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import abc +from typing import Union # noqa: F401 + +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship + +NODE_KEY = 'KEY' +NODE_LABEL = 'LABEL' + +RELATION_START_KEY = 'START_KEY' +RELATION_START_LABEL = 'START_LABEL' +RELATION_END_KEY = 'END_KEY' +RELATION_END_LABEL = 'END_LABEL' +RELATION_TYPE = 'TYPE' +RELATION_REVERSE_TYPE = 'REVERSE_TYPE' + + +class GraphSerializable(object, metaclass=abc.ABCMeta): + """ + A Serializable abstract class asks subclass to implement next node or + next relation in dict form so that it can be serialized to CSV file. + + Any model class that needs to be pushed to a graph database should inherit this class. + """ + + def __init__(self) -> None: + pass + + @abc.abstractmethod + def create_next_node(self) -> Union[GraphNode, None]: + """ + Creates GraphNode the process that consumes this class takes the output + serializes to the desired graph database. + + :return: a GraphNode or None if no more records to serialize + """ + raise NotImplementedError + + @abc.abstractmethod + def create_next_relation(self) -> Union[GraphRelationship, None]: + """ + Creates GraphRelationship the process that consumes this class takes the output + serializes to the desired graph database. + + :return: a GraphRelationship or None if no more record to serialize + """ + raise NotImplementedError + + def next_node(self) -> Union[GraphNode, None]: + node_dict = self.create_next_node() + if not node_dict: + return None + + self._validate_node(node_dict) + return node_dict + + def next_relation(self) -> Union[GraphRelationship, None]: + relation_dict = self.create_next_relation() + if not relation_dict: + return None + + self._validate_relation(relation_dict) + return relation_dict + + def _validate_node(self, node: GraphNode) -> None: + node_id, node_label, _ = node + + if node_id is None: + raise RuntimeError('Required header missing. Required attributes id and label , Missing: id') + + if node_label is None: + raise RuntimeError('Required header missing. 
Required attributes id and label , Missing: label') + + self._validate_label_value(node_label) + + def _validate_relation(self, relation: GraphRelationship) -> None: + self._validate_label_value(relation.start_label) + self._validate_label_value(relation.end_label) + self._validate_relation_type_value(relation.type) + self._validate_relation_type_value(relation.reverse_type) + + def _validate_relation_type_value(self, value: str) -> None: + if not value.isupper(): + raise RuntimeError(f'TYPE needs to be upper case: {value}') + + def _validate_label_value(self, value: str) -> None: + if not value.istitle(): + raise RuntimeError(f'LABEL should only have upper case character on its first one: {value}') diff --git a/databuilder/databuilder/models/owner.py b/databuilder/databuilder/models/owner.py new file mode 100644 index 0000000000..9d34209123 --- /dev/null +++ b/databuilder/databuilder/models/owner.py @@ -0,0 +1,167 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + Iterator, List, Optional, Union, +) + +from amundsen_common.utils.atlas import AtlasCommonParams, AtlasCommonTypes +from amundsen_rds.models import RDSModel +from amundsen_rds.models.dashboard import DashboardOwner as RDSDashboardOwner +from amundsen_rds.models.table import TableOwner as RDSTableOwner +from amundsen_rds.models.user import User as RDSUser + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.owner_constants import OWNER_OF_OBJECT_RELATION_TYPE, OWNER_RELATION_TYPE +from databuilder.models.table_metadata import TableMetadata +from databuilder.models.table_serializable import TableSerializable +from databuilder.models.user import User +from databuilder.serializers.atlas_serializer import get_entity_attrs +from databuilder.utils.atlas import AtlasRelationshipTypes, AtlasSerializedEntityOperation + + +class Owner(GraphSerializable, TableSerializable, AtlasSerializable): + LABELS_PERMITTED_TO_HAVE_OWNER = ['Table', 'Dashboard', 'Feature'] + + def __init__(self, + start_label: str, + start_key: str, + owner_emails: Union[List, str], + ) -> None: + if start_label not in Owner.LABELS_PERMITTED_TO_HAVE_OWNER: + raise Exception(f'owners for {start_label} are not supported') + self.start_label = start_label + self.start_key = start_key + if isinstance(owner_emails, str): + owner_emails = owner_emails.split(',') + self.owner_emails = [email.strip().lower() for email in owner_emails] + + self._node_iter = self._create_node_iterator() + self._relation_iter = self._create_relation_iterator() + self._record_iter = self._create_record_iterator() + self._atlas_entity_iterator = self._create_next_atlas_entity() + self._atlas_relation_iterator = self._create_atlas_relation_iterator() + + def __repr__(self) -> str: + return f'Owner({self.start_label!r}, {self.start_key!r}, {self.owner_emails!r})' + + def create_next_node(self) -> Optional[GraphNode]: + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Optional[GraphRelationship]: + try: + return next(self._relation_iter) + 
except StopIteration: + return None + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iter) + except StopIteration: + return None + + def _create_node_iterator(self) -> Iterator[GraphNode]: + for email in self.owner_emails: + if email: + yield GraphNode( + key=User.get_user_model_key(email=email), + label=User.USER_NODE_LABEL, + attributes={ + User.USER_NODE_EMAIL: email, + } + ) + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + for email in self.owner_emails: + if email: + yield GraphRelationship( + start_label=self.start_label, + start_key=self.start_key, + end_label=User.USER_NODE_LABEL, + end_key=User.get_user_model_key(email=email), + type=OWNER_RELATION_TYPE, + reverse_type=OWNER_OF_OBJECT_RELATION_TYPE, + attributes={} + ) + + def _create_record_iterator(self) -> Iterator[RDSModel]: + for email in self.owner_emails: + if email: + user_record = RDSUser( + rk=User.get_user_model_key(email=email), + email=email + ) + yield user_record + + if self.start_label == TableMetadata.TABLE_NODE_LABEL: + yield RDSTableOwner( + table_rk=self.start_key, + user_rk=User.get_user_model_key(email=email), + ) + elif self.start_label == DashboardMetadata.DASHBOARD_NODE_LABEL: + yield RDSDashboardOwner( + dashboard_rk=self.start_key, + user_rk=User.get_user_model_key(email=email) + ) + else: + raise Exception(f'{self.start_label}<>Owner relationship is not table serializable') + + def _create_atlas_owner_entity(self, owner: str) -> AtlasEntity: + attrs_mapping = [ + (AtlasCommonParams.qualified_name, owner), + ('email', owner) + ] + + entity_attrs = get_entity_attrs(attrs_mapping) + + entity = AtlasEntity( + typeName=AtlasCommonTypes.user, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=entity_attrs, + relationships=None + ) + + return entity + + def _create_atlas_owner_relation(self, owner: str) -> AtlasRelationship: + table_relationship = AtlasRelationship( + relationshipType=AtlasRelationshipTypes.resource_owner, + entityType1=AtlasCommonTypes.data_set, + entityQualifiedName1=self.start_key, + entityType2=AtlasCommonTypes.user, + entityQualifiedName2=User.get_user_model_key(email=owner), + attributes={} + ) + + return table_relationship + + def _create_next_atlas_entity(self) -> Iterator[AtlasEntity]: + for owner in self.owner_emails: + if owner: + yield self._create_atlas_owner_entity(owner) + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + try: + return next(self._atlas_entity_iterator) + except StopIteration: + return None + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + try: + return next(self._atlas_relation_iterator) + except StopIteration: + return None + + def _create_atlas_relation_iterator(self) -> Iterator[AtlasRelationship]: + for owner in self.owner_emails: + if owner: + yield self._create_atlas_owner_relation(owner) diff --git a/databuilder/databuilder/models/owner_constants.py b/databuilder/databuilder/models/owner_constants.py new file mode 100644 index 0000000000..c5f5ed5a1c --- /dev/null +++ b/databuilder/databuilder/models/owner_constants.py @@ -0,0 +1,6 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + + +OWNER_RELATION_TYPE = 'OWNER' +OWNER_OF_OBJECT_RELATION_TYPE = 'OWNER_OF' diff --git a/databuilder/databuilder/models/query/__init__.py b/databuilder/databuilder/models/query/__init__.py new file mode 100644 index 0000000000..9f11cb8af3 --- /dev/null +++ b/databuilder/databuilder/models/query/__init__.py @@ -0,0 +1,7 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from .query import QueryMetadata # noqa +from .query_execution import QueryExecutionsMetadata # noqa +from .query_join import QueryJoinMetadata # noqa +from .query_where import QueryWhereMetadata # noqa diff --git a/databuilder/databuilder/models/query/base.py b/databuilder/databuilder/models/query/base.py new file mode 100644 index 0000000000..e270b765f5 --- /dev/null +++ b/databuilder/databuilder/models/query/base.py @@ -0,0 +1,72 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Iterator + +from databuilder.models.graph_serializable import GraphSerializable + + +class QueryBase(GraphSerializable): + @staticmethod + def _normalize(sql: str) -> str: + """ + Normalizes a SQL query or SQL expression. + + No checks are made to ensure that the input is valid SQL. + This is not a full normalization. The following operations are preformed: + + - Any run of whitespace characters outside of a quoted region is replaces by a single ' ' character. + - Characters outside of quoted regions are made lower case. + - If present, a trailing ';' is removed from the query. + + Note: + Making characters outside quoted regions lower case does not in general result in an equivalent SQL statement. + For example, with MySQL the case sensitivity of table names is operating system dependant. + In practice, modern systems rarely rely on case sensitivity, and since making the non-quoted regions of the + query lowercase is very helpful in identifying queries, we go ahead and do so. + + Also, this method fails to identify expressions such as `1 + 2` and `1+2`. + There are likely too many special cases in this area to make much progress without doing a proper parse. + """ + text = sql.strip() + it = iter(text) + sb = [] + for c in it: + if c.isspace(): + c = QueryBase._process_whitespace(it) + sb.append(' ') + sb.append(c.lower()) + if c in ('`', '"', "'"): + for d in QueryBase._process_quoted(it, c): + sb.append(d) + if sb[-1] == ';': + sb.pop() + return ''.join(sb) + + @staticmethod + def _process_quoted(it: Iterator[str], quote: str) -> Iterator[str]: + """ + Yields characters up to and including the first occurrence of the (non-escaped) character `quote`. + + Allows `quote` to be escaped with '\\'. + """ + p = '' + for c in it: + yield c + if c == quote and p != '\\': + break + p = c + + @staticmethod + def _process_whitespace(it: Iterator[str]) -> str: + """ + Returns the first non-whitespace character encountered. + + This should never return `None` since the query text is striped before being processed. + That is, if the current character is a whitespace character, then there remains at least one non-whitespace + character in the stream. 
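+
+        For example, if the remaining characters were '   FROM' (hypothetical input), the run of
+        whitespace would be consumed and 'F' returned.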
+ """ + for c in it: + if not c.isspace(): + return c + raise ValueError("Input string was not stripped!") diff --git a/databuilder/databuilder/models/query/query.py b/databuilder/databuilder/models/query/query.py new file mode 100644 index 0000000000..ead286f9ff --- /dev/null +++ b/databuilder/databuilder/models/query/query.py @@ -0,0 +1,166 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import hashlib +from typing import ( + Iterator, List, Optional, +) + +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.query.base import QueryBase +from databuilder.models.table_metadata import TableMetadata +from databuilder.models.user import User as UserMetadata + + +class QueryMetadata(QueryBase): + """ + Query model. This creates a Query object as well as relationships + between the Query and the Table(s) that are used within the query. + The list of TableMetadata corresponding to the tables used in the + query must be provided. Optionally, the ID of the user that executed + the query can be provided as well. + + By default, all tables and users must already exist in the database + before this QueryMetadata object in order to create the relationships. + Implementers have the option to set `yield_relation_nodes` = True + in order to create all user and table nodes on the fly at the time + that this QueryMetadata is created. + """ + NODE_LABEL = 'Query' + KEY_FORMAT = '{sql_hash}' + + # Relation between entity and query + TABLE_QUERY_RELATION_TYPE = 'HAS_QUERY' + INVERSE_TABLE_QUERY_RELATION_TYPE = 'QUERY_FOR' + + USER_QUERY_RELATION_TYPE = 'EXECUTED_QUERY' + INVERSE_USER_QUERY_RELATION_TYPE = 'EXECUTED_BY' + + # Attributes + SQL = 'sql' + TABLES = 'tables' + + def __init__( + self, + sql: str, + tables: List[TableMetadata], + clean_sql: Optional[str] = None, # + user: Optional[UserMetadata] = None, # Key for the user that executed the query + yield_relation_nodes: bool = False # Allow creation of related nodes if they do not exist + ): + """ + :param sql: Full, raw SQL for a given Query + :param tables: List of table meteadata objects corresponding to tables in the query + :param clean_sql: A modified sql that should be used to create the hash if available. This + may be used if you have a query that is run on a set schedule but the where clause has + a new date or hour value injected before the query is run. You can "clean" that value + and pass in a SQL string that corresponds to the underlying query - which would should + remain the same across executions. + :param user: The user that executed the query. + :param yield_relation_nodes: A boolean, indicating whether or not all tables and users + associated to this query should have nodes created if they do not already exist. + """ + self.sql = sql + self.clean_sql = clean_sql + self.sql_hash = self._get_sql_hash(clean_sql or sql) + self.tables = tables + self.table_keys = [tm._get_table_key() for tm in tables] + self.user = user + self.yield_relation_nodes = yield_relation_nodes + self._sql_begin = sql[:25] + '...' + self._node_iter = self._create_next_node() + self._relation_iter = self._create_relation_iterator() + + def __repr__(self) -> str: + return f'QueryMetadata(SQL: {self._sql_begin}, Tables: {self.table_keys})' + + def _get_sql_hash(self, sql: str) -> str: + """ + Generates a deterministic SQL hash. Attempts to remove any formatting from the + SQL code where possible. 
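+
+        For illustration, two hypothetical queries such as 'SELECT  *  FROM foo;' and
+        'select * from foo' normalize to the same string and therefore share the same hash
+        (and thus the same Query key).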
+ """ + sql_no_fmt = self._normalize(sql) + return hashlib.md5(sql_no_fmt.encode('utf-8')).hexdigest() + + def create_next_node(self) -> Optional[GraphNode]: + # return the string representation of the data + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Optional[GraphRelationship]: + try: + return next(self._relation_iter) + except StopIteration: + return None + + @staticmethod + def get_key(sql_hash: str) -> str: + return QueryMetadata.KEY_FORMAT.format(sql_hash=sql_hash) + + def get_key_self(self) -> str: + return QueryMetadata.get_key(self.sql_hash) + + def get_query_relations(self) -> List[GraphRelationship]: + relations = [] + for table_key in self.table_keys: + table_relation = GraphRelationship( + start_label=TableMetadata.TABLE_NODE_LABEL, + end_label=self.NODE_LABEL, + start_key=table_key, + end_key=self.get_key_self(), + type=self.TABLE_QUERY_RELATION_TYPE, + reverse_type=self.INVERSE_TABLE_QUERY_RELATION_TYPE, + attributes={} + ) + relations.append(table_relation) + + if self.user: + user_relation = GraphRelationship( + start_label=UserMetadata.USER_NODE_LABEL, + end_label=self.NODE_LABEL, + start_key=self.user.get_user_model_key(email=self.user.email), + end_key=self.get_key_self(), + type=self.USER_QUERY_RELATION_TYPE, + reverse_type=self.INVERSE_USER_QUERY_RELATION_TYPE, + attributes={} + ) + relations.append(user_relation) + return relations + + def _create_next_node(self) -> Iterator[GraphNode]: + """ + Create query nodes + :return: + """ + yield GraphNode( + key=self.get_key_self(), + label=self.NODE_LABEL, + attributes={ + self.SQL: self.sql + } + ) + if self.yield_relation_nodes: + for table in self.tables: + for tbl_item in table._create_next_node(): + yield tbl_item + if self.user: + usr = self.user.create_next_node() + while usr: + yield usr + usr = self.user.create_next_node() + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + relations = self.get_query_relations() + for relation in relations: + yield relation + + if self.yield_relation_nodes: + for table in self.tables: + for tbl_rel in table._create_next_relation(): + yield tbl_rel + if self.user: + for usr_rel in self.user._create_relation_iterator(): + yield usr_rel diff --git a/databuilder/databuilder/models/query/query_execution.py b/databuilder/databuilder/models/query/query_execution.py new file mode 100644 index 0000000000..6d717f885e --- /dev/null +++ b/databuilder/databuilder/models/query/query_execution.py @@ -0,0 +1,136 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + Iterator, Optional, Union, +) + +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.query.query import QueryMetadata + + +class QueryExecutionsMetadata(GraphSerializable): + """ + The Amundsen Query Executions model represents an aggregation for the number + of times a query was executed within a given time window. + + Query executions are aggregated to time window to enable easily adding and + dropping new query execution aggregations without having to maintain + all instances that a query was executed in the database. Query Executions only + contain a start time and a window duration, although the window duration is + only used for informational purposes. 
Amundsen does not apply validation that + query executions do not overlap, therefore, it is important that any implementation + of a query execution is able to deterministically retrieve non-overlapping queries + between data builder runs. + """ + NODE_LABEL = 'Execution' + KEY_FORMAT = '{query_key}-{start_time}' + + # Relation between entity and query + QUERY_EXECUTION_RELATION_TYPE = 'HAS_EXECUTION' + INVERSE_QUERY_EXECUTION_RELATION_TYPE = 'EXECUTION_OF' + + # Execution window ENUMs + EXECUTION_WINDOW_HOURLY = 'hourly' + EXECUTION_WINDOW_DAILY = 'daily' + EXECUTION_WINDOW_WEEKLY = 'weekly' + + # Attributes + START_TIME = 'start_time' + EXECUTION_COUNT = 'execution_count' + WINDOW_DURATION = 'window_duration' + + def __init__(self, + query_metadata: QueryMetadata, + start_time: int, + execution_count: int, + window_duration: str = EXECUTION_WINDOW_DAILY, # Purely for descriptive purposes + yield_relation_nodes: bool = False + ): + """ + :param query_metadata: The Query metadata object that this execution belongs to + :param start_time: The time the query execution window started. This should + consistently be supplied as either seconds or milliseconds since epoch. + :param execution_count: The count of the number of times this query was executed + within the window + :param window_duration: A description of the window duration, e.g. daily, hourly + :param yield_relation_nodes: A boolean, indicating whether or not the query + associated to this execution should have it's node created if it does not + already exist. + """ + self.query_metadata = query_metadata + self.start_time = start_time + self.execution_count = execution_count + self.window_duration = window_duration + self.yield_relation_nodes = yield_relation_nodes + self._node_iter = self._create_next_node() + self._relation_iter = self._create_relation_iterator() + + def __repr__(self) -> str: + return ( + f'QueryExecutionsMetadata(Query: {self.query_metadata.get_key_self()}, Start Time: {self.start_time}, ' + f'Window Duration: {self.window_duration}, Count: {self.execution_count})' + ) + + def create_next_node(self) -> Optional[GraphNode]: + # return the string representation of the data + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Optional[GraphRelationship]: + try: + return next(self._relation_iter) + except StopIteration: + return None + + @staticmethod + def get_key(query_key: str, start_time: Union[str, int]) -> str: + return QueryExecutionsMetadata.KEY_FORMAT.format(query_key=query_key, start_time=start_time) + + def get_key_self(self) -> str: + return QueryExecutionsMetadata.get_key(query_key=self.query_metadata.get_key_self(), start_time=self.start_time) + + def get_query_relations(self) -> Iterator[GraphRelationship]: + yield GraphRelationship( + start_label=QueryMetadata.NODE_LABEL, + end_label=self.NODE_LABEL, + start_key=self.query_metadata.get_key_self(), + end_key=self.get_key_self(), + type=self.QUERY_EXECUTION_RELATION_TYPE, + reverse_type=self.INVERSE_QUERY_EXECUTION_RELATION_TYPE, + attributes={} + ) + + def _create_next_node(self) -> Iterator[GraphNode]: + """ + Create query nodes + :return: + """ + # TODO: Should query metadata yiled tables as well? Otherwise if a table does not exist + # before this script is ran then the query relatoinship will not get created. + # Ideally this relationship wouldn't be "lost" but created once the table is created as well. 
+ yield GraphNode( + key=self.get_key_self(), + label=self.NODE_LABEL, + attributes={ + self.START_TIME: self.start_time, + self.EXECUTION_COUNT: self.execution_count, + self.WINDOW_DURATION: self.window_duration, + } + ) + if self.yield_relation_nodes and self.query_metadata: + for query_item in self.query_metadata._create_next_node(): + yield query_item + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + relations = self.get_query_relations() + for relation in relations: + yield relation + + if self.yield_relation_nodes and self.query_metadata: + for query_rel in self.query_metadata._create_relation_iterator(): + yield query_rel diff --git a/databuilder/databuilder/models/query/query_join.py b/databuilder/databuilder/models/query/query_join.py new file mode 100644 index 0000000000..971686073d --- /dev/null +++ b/databuilder/databuilder/models/query/query_join.py @@ -0,0 +1,204 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Iterator, Optional + +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.query.query import QueryMetadata +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + + +class QueryJoinMetadata(GraphSerializable): + """ + A Join clause used between two tables within a query + """ + NODE_LABEL = 'Join' + KEY_FORMAT = '{join_type}-{left_column_key}-{operator}-{right_column_key}' + + # Relation between entity and query + COLUMN_JOIN_RELATION_TYPE = 'COLUMN_JOINS_WITH' + INVERSE_COLUMN_JOIN_RELATION_TYPE = 'JOIN_OF_COLUMN' + + QUERY_JOIN_RELATION_TYPE = 'QUERY_JOINS_WITH' + INVERSE_QUERY_JOIN_RELATION_TYPE = 'JOIN_OF_QUERY' + + # Node attributes + JOIN_TYPE = 'join_type' + JOIN_OPERATOR = 'operator' + JOIN_SQL = 'join_sql' + LEFT_TABLE_KEY = 'left_table_key' + LEFT_DATABASE = 'left_database' + LEFT_CLUSTER = 'left_cluster' + LEFT_SCHEMA = 'left_schema' + LEFT_TABLE = 'left_table' + RIGHT_TABLE_KEY = 'right_table_key' + RIGHT_DATABASE = 'right_database' + RIGHT_CLUSTER = 'right_cluster' + RIGHT_SCHEMA = 'right_schema' + RIGHT_TABLE = 'right_table' + + def __init__(self, + left_table: TableMetadata, + right_table: TableMetadata, + left_column: ColumnMetadata, + right_column: ColumnMetadata, + join_type: str, + join_operator: str, + join_sql: str, + query_metadata: Optional[QueryMetadata] = None, + yield_relation_nodes: bool = False): + """ + :param left_table: The table joined on the left side of the join clause + :param right_table: The table joined on the right side of the join clause + :param left_column: The column from the left table used in the join + :param right_column: The column from the right table used in the join + :param join_type: A free form string representing the type of join, examples + include: inner join, right join, full join, etc. + :param join_operator: The operator used in the join, examples include: =, >, etc. + :param query_metadata: The Query metadata object that this where clause belongs to, this + is optional to allow creating static QueryJoinMetadata objects to show on tables + without the complexity of creating QueryMetadata + :param yield_relation_nodes: A boolean, indicating whether or not the query metadata + and tables associated to this Join should have nodes created if they does not + already exist. 
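+
+        The resulting Join node key has the form '<join type>-<left column key>-<operator>-<right column key>',
+        e.g. (hypothetically) 'inner-join-hive://gold.core/orders/user_id-=-hive://gold.core/users/id'.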
+        """
+        # For inner joins we don't want to duplicate joins if the other table
+        # comes first in the join clause since it produces the same effect.
+        # This ONLY applies to inner joins, and you may need to massage your data
+        # so that join_type has the proper value.
+        swap_left_right = False
+        if join_operator == '=' and join_type == 'inner join':
+            tables_sorted = sorted([left_table._get_table_key(), right_table._get_table_key()])
+            if tables_sorted[0] == right_table._get_table_key():
+                swap_left_right = True
+
+        self.left_table = right_table if swap_left_right else left_table
+        self.right_table = left_table if swap_left_right else right_table
+        self.left_column = right_column if swap_left_right else left_column
+        self.right_column = left_column if swap_left_right else right_column
+
+        self.join_type = join_type
+        self.join_operator = join_operator
+        self.join_sql = join_sql
+        self.query_metadata = query_metadata
+        self.yield_relation_nodes = yield_relation_nodes
+        self._node_iter = self._create_next_node()
+        self._relation_iter = self._create_relation_iterator()
+
+    def __repr__(self) -> str:
+        return (
+            f'QueryJoinMetadata(Left Table: {self.left_table._get_table_key()}, '
+            f'Right Table: {self.right_table._get_table_key()})'
+        )
+
+    def create_next_node(self) -> Optional[GraphNode]:
+        # Return the next GraphNode, or None once all nodes have been emitted
+        try:
+            return next(self._node_iter)
+        except StopIteration:
+            return None
+
+    def create_next_relation(self) -> Optional[GraphRelationship]:
+        try:
+            return next(self._relation_iter)
+        except StopIteration:
+            return None
+
+    @staticmethod
+    def get_key(left_column_key: str, right_column_key: str, join_type: str, operator: str) -> str:
+        join_no_space = join_type.replace(' ', '-')
+        return QueryJoinMetadata.KEY_FORMAT.format(left_column_key=left_column_key,
+                                                   right_column_key=right_column_key,
+                                                   join_type=join_no_space,
+                                                   operator=operator)
+
+    def get_key_self(self) -> str:
+        return QueryJoinMetadata.get_key(left_column_key=self.left_table._get_col_key(col=self.left_column),
+                                         right_column_key=self.right_table._get_col_key(col=self.right_column),
+                                         join_type=self.join_type,
+                                         operator=self.join_operator)
+
+    def get_query_relations(self) -> Iterator[GraphRelationship]:
+
+        # Left Column
+        yield GraphRelationship(
+            start_label=ColumnMetadata.COLUMN_NODE_LABEL,
+            end_label=self.NODE_LABEL,
+            start_key=self.left_table._get_col_key(col=self.left_column),
+            end_key=self.get_key_self(),
+            type=self.COLUMN_JOIN_RELATION_TYPE,
+            reverse_type=self.INVERSE_COLUMN_JOIN_RELATION_TYPE,
+            attributes={}
+        )
+
+        # Right Column
+        yield GraphRelationship(
+            start_label=ColumnMetadata.COLUMN_NODE_LABEL,
+            end_label=self.NODE_LABEL,
+            start_key=self.right_table._get_col_key(col=self.right_column),
+            end_key=self.get_key_self(),
+            type=self.COLUMN_JOIN_RELATION_TYPE,
+            reverse_type=self.INVERSE_COLUMN_JOIN_RELATION_TYPE,
+            attributes={}
+        )
+
+        if self.query_metadata:
+            yield GraphRelationship(
+                start_label=QueryMetadata.NODE_LABEL,
+                end_label=self.NODE_LABEL,
+                start_key=self.query_metadata.get_key_self(),
+                end_key=self.get_key_self(),
+                type=self.QUERY_JOIN_RELATION_TYPE,
+                reverse_type=self.INVERSE_QUERY_JOIN_RELATION_TYPE,
+                attributes={}
+            )
+
+    def _create_next_node(self) -> Iterator[GraphNode]:
+        """
+        Create the Join node (plus related table and query nodes when yield_relation_nodes is set)
+        :return:
+        """
+        yield GraphNode(
+            key=self.get_key_self(),
+            label=self.NODE_LABEL,
+            attributes={
+                self.JOIN_TYPE: self.join_type,
+                self.JOIN_OPERATOR: self.join_operator,
+                self.JOIN_SQL: self.join_sql,
+                self.LEFT_TABLE_KEY: self.left_table._get_table_key(),
+                self.LEFT_DATABASE: self.left_table.database,
+                self.LEFT_CLUSTER: self.left_table.cluster,
+                self.LEFT_SCHEMA: self.left_table.schema,
+                self.LEFT_TABLE: self.left_table.name,
+                self.RIGHT_TABLE_KEY: self.right_table._get_table_key(),
+                self.RIGHT_DATABASE: self.right_table.database,
+                self.RIGHT_CLUSTER: self.right_table.cluster,
+                self.RIGHT_SCHEMA: self.right_table.schema,
+                self.RIGHT_TABLE: self.right_table.name
+            }
+        )
+
+        if self.yield_relation_nodes:
+            for l_tbl_item in self.left_table._create_next_node():
+                yield l_tbl_item
+            for r_tbl_item in self.right_table._create_next_node():
+                yield r_tbl_item
+            if self.query_metadata:
+                for query_item in self.query_metadata._create_next_node():
+                    yield query_item
+
+    def _create_relation_iterator(self) -> Iterator[GraphRelationship]:
+        relations = self.get_query_relations()
+        for relation in relations:
+            yield relation
+
+        if self.yield_relation_nodes:
+            for l_tbl_rel in self.left_table._create_next_relation():
+                yield l_tbl_rel
+            for r_tbl_rel in self.right_table._create_next_relation():
+                yield r_tbl_rel
+            if self.query_metadata:
+                for query_rel in self.query_metadata._create_relation_iterator():
+                    yield query_rel
diff --git a/databuilder/databuilder/models/query/query_where.py b/databuilder/databuilder/models/query/query_where.py
new file mode 100644
index 0000000000..d6bd9c9814
--- /dev/null
+++ b/databuilder/databuilder/models/query/query_where.py
@@ -0,0 +1,176 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+import hashlib
+from typing import (
+    Iterator, List, Optional,
+)
+
+from databuilder.models.graph_node import GraphNode
+from databuilder.models.graph_relationship import GraphRelationship
+from databuilder.models.query.base import QueryBase
+from databuilder.models.query.query import QueryMetadata
+from databuilder.models.table_metadata import ColumnMetadata, TableMetadata
+
+
+class QueryWhereMetadata(QueryBase):
+    """
+    A Where clause used on a query.
+    """
+    NODE_LABEL = 'Where'
+    KEY_FORMAT = '{table_hash}-{where_hash}'
+
+    # Relation between table and query
+    COLUMN_WHERE_RELATION_TYPE = 'USES_WHERE_CLAUSE'
+    INVERSE_COLUMN_WHERE_RELATION_TYPE = 'WHERE_CLAUSE_USED_ON'
+
+    QUERY_WHERE_RELATION_TYPE = 'HAS_WHERE_CLAUSE'
+    INVERSE_QUERY_WHERE_RELATION_TYPE = 'WHERE_CLAUSE_OF'
+
+    # Node attributes
+    WHERE_CLAUSE = 'where_clause'
+    LEFT_ARG = 'left_arg'
+    RIGHT_ARG = 'right_arg'
+    OPERATOR = 'operator'
+    ALIAS_MAPPING = 'alias_mapping'
+
+    def __init__(self,
+                 tables: List[TableMetadata],
+                 where_clause: str,
+                 left_arg: Optional[str],
+                 right_arg: Optional[str],
+                 operator: Optional[str],
+                 query_metadata: Optional[QueryMetadata] = None,
+                 yield_relation_nodes: bool = False):
+        """
+        :param tables: List of table metadata objects corresponding to tables in this where clause
+        :param where_clause: A string representation of the SQL where clause
+        :param left_arg: An optional string representing the left side of the where clause, e.g.
+            in the clause (where x < 3), this would be "x"
+        :param operator: An optional string representing the operator in the where clause, e.g.
+            in the clause (where x < 3), this would be "<"
+        :param right_arg: An optional string representing the right side of the where clause, e.g.
+ in the clause (where x < 3), this would be "3" + :param query_metadata: The Query metadata object that this where clause belongs to, this + is optional to allow creating static QueryWhereMetadata objects to show on tables + without the complexity of creating QueryMetadata + :param yield_relation_nodes: A boolean, indicating whether or not the query metadata + and tables associated to this Where should have nodes created if they does not + already exist. + """ + self.tables = tables + self.query_metadata = query_metadata + self.where_clause = where_clause + self.left_arg = left_arg + self.right_arg = right_arg + self.operator = operator + self.yield_relation_nodes = yield_relation_nodes + self._table_hash = self._get_table_hash(self.tables) + self._where_hash = self._get_where_hash(self.where_clause) + self._node_iter = self._create_next_node() + self._relation_iter = self._create_relation_iterator() + + def __repr__(self) -> str: + tbl_str = self.tables[0]._get_table_key() + if len(self.tables) > 1: + tbl_str += f' + {len(self.tables) - 1} other tables' + return f'QueryWhereMetadata(Table: {tbl_str}, {self.where_clause[:25]})' + + def _get_table_hash(self, tables: List[TableMetadata]) -> str: + """ + Generates a unique hash for a set of tables that are associated to a where clause. Since + we do not want multiple instances of this where clause represented in the database we may + need to link mulitple tables to this where clause. We do this by creating a single, unique + key across multiple tables by concatenating all of the table keys together and creating a + hash (to shorten the value). + """ + tbl_keys = ''.join(list(sorted([t._get_table_key() for t in tables]))) + return hashlib.md5(tbl_keys.encode('utf-8')).hexdigest() + + def _get_where_hash(self, where_clause: str) -> str: + """ + Generates a unique hash for a where clause. 
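+
+        The clause is normalized first (whitespace collapsed, text outside quotes lower-cased,
+        trailing ';' dropped), so hypothetical variants such as 'X  < 3' and 'x < 3' hash identically.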
+ """ + sql_no_fmt = self._normalize(where_clause) + return hashlib.md5(sql_no_fmt.encode('utf-8')).hexdigest() + + def create_next_node(self) -> Optional[GraphNode]: + # return the string representation of the data + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Optional[GraphRelationship]: + try: + return next(self._relation_iter) + except StopIteration: + return None + + @staticmethod + def get_key(table_hash: str, where_hash: str) -> str: + return QueryWhereMetadata.KEY_FORMAT.format(table_hash=table_hash, where_hash=where_hash) + + def get_key_self(self) -> str: + return QueryWhereMetadata.get_key(table_hash=self._table_hash, where_hash=self._where_hash) + + def get_query_relations(self) -> Iterator[GraphRelationship]: + for table in self.tables: + for col in table.columns: + yield GraphRelationship( + start_label=ColumnMetadata.COLUMN_NODE_LABEL, + end_label=self.NODE_LABEL, + start_key=table._get_col_key(col), + end_key=self.get_key_self(), + type=self.COLUMN_WHERE_RELATION_TYPE, + reverse_type=self.INVERSE_COLUMN_WHERE_RELATION_TYPE, + attributes={} + ) + + # Optional Query to Where Clause + if self.query_metadata: + yield GraphRelationship( + start_label=QueryMetadata.NODE_LABEL, + end_label=self.NODE_LABEL, + start_key=self.query_metadata.get_key_self(), + end_key=self.get_key_self(), + type=self.QUERY_WHERE_RELATION_TYPE, + reverse_type=self.INVERSE_QUERY_WHERE_RELATION_TYPE, + attributes={} + ) + + def _create_next_node(self) -> Iterator[GraphNode]: + """ + Create query nodes + :return: + """ + yield GraphNode( + key=self.get_key_self(), + label=self.NODE_LABEL, + attributes={ + self.WHERE_CLAUSE: self.where_clause, + self.LEFT_ARG: self.left_arg, + self.RIGHT_ARG: self.right_arg, + self.OPERATOR: self.operator + } + ) + if self.yield_relation_nodes: + for table in self.tables: + for tbl_item in table._create_next_node(): + yield tbl_item + if self.query_metadata: + for query_item in self.query_metadata._create_next_node(): + yield query_item + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + relations = self.get_query_relations() + for relation in relations: + yield relation + + if self.yield_relation_nodes: + for table in self.tables: + for tbl_rel in table._create_next_relation(): + yield tbl_rel + if self.query_metadata: + for query_rel in self.query_metadata._create_relation_iterator(): + yield query_rel diff --git a/databuilder/databuilder/models/report.py b/databuilder/databuilder/models/report.py new file mode 100644 index 0000000000..051167fd40 --- /dev/null +++ b/databuilder/databuilder/models/report.py @@ -0,0 +1,145 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from typing import Iterator, Union + +from amundsen_common.utils.atlas import AtlasCommonParams, AtlasCommonTypes + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.serializers.atlas_serializer import get_entity_attrs +from databuilder.utils.atlas import AtlasRelationshipTypes, AtlasSerializedEntityOperation + + +class ResourceReport(GraphSerializable, AtlasSerializable): + """ + Resource Report matching model + + Report represents a document that can be linked to any resource (like a table) in Amundsen. + + Example would be Pandas Profiling HTML report containing full advanced profile of a table. + """ + + RESOURCE_REPORT_LABEL = 'Report' + + RESOURCE_REPORT_NAME = 'name' + RESOURCE_REPORT_URL = 'url' + + REPORT_KEY_FORMAT = '{resource_uri}/_report/{report_name}' + + REPORT_RESOURCE_RELATION_TYPE = 'REFERS_TO' + RESOURCE_REPORT_RELATION_TYPE = 'HAS_REPORT' + + def __init__(self, + name: str, + url: str, + resource_uri: str, + resource_label: str, # for example 'Table' + ) -> None: + self.report_name = name + self.report_url = url + + self.resource_uri = resource_uri + self.resource_label = resource_label + + self.resource_report_key = self.get_resource_model_key() + + self._node_iter = self._create_node_iterator() + self._relation_iter = self._create_relation_iterator() + self._atlas_entity_iterator = self._create_next_atlas_entity() + self._atlas_relation_iterator = self._create_atlas_relation_iterator() + + def get_resource_model_key(self) -> str: + return ResourceReport.REPORT_KEY_FORMAT.format(resource_uri=self.resource_uri, report_name=self.report_name) + + def create_next_node(self) -> Union[GraphNode, None]: + # creates new node + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Union[GraphRelationship, None]: + try: + return next(self._relation_iter) + except StopIteration: + return None + + def _create_node_iterator(self) -> Iterator[GraphNode]: + """ + Create an application node + :return: + """ + report_node = GraphNode( + key=self.resource_report_key, + label=ResourceReport.RESOURCE_REPORT_LABEL, + attributes={ + ResourceReport.RESOURCE_REPORT_NAME: self.report_name, + ResourceReport.RESOURCE_REPORT_URL: self.report_url + } + ) + + yield report_node + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + """ + Create relations between application and table nodes + :return: + """ + graph_relationship = GraphRelationship( + start_key=self.resource_uri, + start_label=self.resource_label, + end_key=self.resource_report_key, + end_label=ResourceReport.RESOURCE_REPORT_LABEL, + type=ResourceReport.RESOURCE_REPORT_RELATION_TYPE, + reverse_type=ResourceReport.REPORT_RESOURCE_RELATION_TYPE, + attributes={} + ) + + yield graph_relationship + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + try: + return next(self._atlas_entity_iterator) + except StopIteration: + return None + + def _create_next_atlas_entity(self) -> Iterator[AtlasEntity]: + group_attrs_mapping = [ + (AtlasCommonParams.qualified_name, self.resource_report_key), + ('name', self.report_name), + ('url', self.report_url) + ] + + 
entity_attrs = get_entity_attrs(group_attrs_mapping) + + entity = AtlasEntity( + typeName=AtlasCommonTypes.resource_report, + operation=AtlasSerializedEntityOperation.CREATE, + relationships=None, + attributes=entity_attrs, + ) + + yield entity + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + try: + return next(self._atlas_relation_iterator) + except StopIteration: + return None + + def _create_atlas_relation_iterator(self) -> Iterator[AtlasRelationship]: + relationship = AtlasRelationship( + relationshipType=AtlasRelationshipTypes.referenceable_report, + entityType1=self.resource_label, + entityQualifiedName1=self.resource_uri, + entityType2=AtlasCommonTypes.resource_report, + entityQualifiedName2=self.resource_report_key, + attributes={} + ) + + yield relationship diff --git a/databuilder/databuilder/models/schema/__init__.py b/databuilder/databuilder/models/schema/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/models/schema/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/models/schema/schema.py b/databuilder/databuilder/models/schema/schema.py new file mode 100644 index 0000000000..c2174d78b7 --- /dev/null +++ b/databuilder/databuilder/models/schema/schema.py @@ -0,0 +1,154 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import re +from typing import ( + Any, Iterator, Optional, Union, +) + +from amundsen_common.utils.atlas import AtlasCommonParams, AtlasTableTypes +from amundsen_rds.models import RDSModel +from amundsen_rds.models.schema import ( + Schema as RDSSchema, SchemaDescription as RDSSchemaDescription, + SchemaProgrammaticDescription as RDSSchemaProgrammaticDescription, +) + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.description_metadata import DescriptionMetadata +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.schema.schema_constant import ( + SCHEMA_KEY_PATTERN_REGEX, SCHEMA_NAME_ATTR, SCHEMA_NODE_LABEL, +) +from databuilder.models.table_serializable import TableSerializable +from databuilder.serializers.atlas_serializer import get_entity_attrs +from databuilder.utils.atlas import AtlasSerializedEntityOperation + + +class SchemaModel(GraphSerializable, TableSerializable, AtlasSerializable): + def __init__(self, + schema_key: str, + schema: str, + description: Optional[str] = None, + description_source: Optional[str] = None, + **kwargs: Any + ) -> None: + self._schema_key = schema_key + self._schema = schema + self._cluster_key = self._get_cluster_key(schema_key) + self._description = DescriptionMetadata.create_description_metadata(text=description, + source=description_source) \ + if description else None + self._node_iterator = self._create_node_iterator() + self._relation_iterator = self._create_relation_iterator() + self._record_iterator = self._create_record_iterator() + self._atlas_entity_iterator = self._create_next_atlas_entity() + + def create_next_node(self) -> Union[GraphNode, None]: + try: + return next(self._node_iterator) + except StopIteration: + return None + + def 
_create_node_iterator(self) -> Iterator[GraphNode]: + node = GraphNode( + key=self._schema_key, + label=SCHEMA_NODE_LABEL, + attributes={ + SCHEMA_NAME_ATTR: self._schema, + } + ) + yield node + + if self._description: + yield self._description.get_node(self._get_description_node_key()) + + def create_next_relation(self) -> Union[GraphRelationship, None]: + try: + return next(self._relation_iterator) + except StopIteration: + return None + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iterator) + except StopIteration: + return None + + def _create_record_iterator(self) -> Iterator[RDSModel]: + schema_record = RDSSchema( + rk=self._schema_key, + name=self._schema, + cluster_rk=self._cluster_key + ) + yield schema_record + + if self._description: + if self._description.label == DescriptionMetadata.DESCRIPTION_NODE_LABEL: + yield RDSSchemaDescription( + rk=self._get_description_node_key(), + description_source=self._description.source, + description=self._description.text, + schema_rk=self._schema_key + ) + else: + yield RDSSchemaProgrammaticDescription( + rk=self._get_description_node_key(), + description_source=self._description.source, + description=self._description.text, + schema_rk=self._schema_key + ) + + def _get_description_node_key(self) -> str: + desc = self._description.get_description_id() if self._description is not None else '' + return f'{self._schema_key}/{desc}' + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + if self._description: + yield self._description.get_relation(start_node=SCHEMA_NODE_LABEL, + start_key=self._schema_key, + end_key=self._get_description_node_key()) + + def _get_cluster_key(self, schema_key: str) -> str: + schema_key_pattern = re.compile(SCHEMA_KEY_PATTERN_REGEX) + schema_key_match = schema_key_pattern.match(schema_key) + if not schema_key_match: + raise Exception(f'{schema_key} does not match the schema key pattern') + + cluster_key = schema_key_match.group(1) + return cluster_key + + def _create_atlas_schema_entity(self) -> AtlasEntity: + attrs_mapping = [ + (AtlasCommonParams.qualified_name, self._schema_key), + ('name', self._schema_key), + ('description', self._description.text if self._description else '') + ] + + entity_attrs = get_entity_attrs(attrs_mapping) + + # Since Schema cannot exist without Cluster (COMPOSITION relationship type), we assume Schema entity was created + # by different process and we only update schema description here using UPDATE operation. + entity = AtlasEntity( + typeName=AtlasTableTypes.schema, + operation=AtlasSerializedEntityOperation.UPDATE, + attributes=entity_attrs, + relationships=None + ) + + return entity + + def _create_next_atlas_entity(self) -> Iterator[AtlasEntity]: + yield self._create_atlas_schema_entity() + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + try: + return next(self._atlas_entity_iterator) + except StopIteration: + return None + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + pass diff --git a/databuilder/databuilder/models/schema/schema_constant.py b/databuilder/databuilder/models/schema/schema_constant.py new file mode 100644 index 0000000000..2df9a21322 --- /dev/null +++ b/databuilder/databuilder/models/schema/schema_constant.py @@ -0,0 +1,14 @@ +# Copyright Contributors to the Amundsen project. 
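A usage sketch for the ResourceReport model defined earlier in this diff, assuming the databuilder package and its dependencies are importable; the table key and report URL are illustrative values, not real resources:

from databuilder.models.report import ResourceReport

# attach a (hypothetical) Pandas Profiling HTML report to a table resource
report = ResourceReport(
    name='pandas_profiling',
    url='https://reports.example.com/test_schema/test_table.html',  # illustrative URL
    resource_uri='hive://gold.test_schema/test_table',              # illustrative table key
    resource_label='Table',
)

# drain the iterators the way a loader would
node = report.create_next_node()
while node is not None:
    print(node.label, node.key)
    node = report.create_next_node()

relation = report.create_next_relation()
while relation is not None:
    print(relation.start_key, relation.type, relation.end_key)
    relation = report.create_next_relation()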
+# SPDX-License-Identifier: Apache-2.0 + +SCHEMA_NODE_LABEL = 'Schema' + +SCHEMA_NAME_ATTR = 'name' + +SCHEMA_RELATION_TYPE = 'SCHEMA' +SCHEMA_REVERSE_RELATION_TYPE = 'SCHEMA_OF' + +DATABASE_SCHEMA_KEY_FORMAT = '{db}://{cluster}.{schema}' + +# pattern used to match a schema key, e.g., hive://gold.test_schema +SCHEMA_KEY_PATTERN_REGEX = '([a-zA-Z0-9_]+://[a-zA-Z0-9_-]+).[a-zA-Z0-9_.-]+' diff --git a/databuilder/databuilder/models/table_column_usage.py b/databuilder/databuilder/models/table_column_usage.py new file mode 100644 index 0000000000..9712ed95e7 --- /dev/null +++ b/databuilder/databuilder/models/table_column_usage.py @@ -0,0 +1,123 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + Iterable, Iterator, Optional, Union, +) + +from amundsen_rds.models import RDSModel + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.table_metadata import TableMetadata +from databuilder.models.table_serializable import TableSerializable +from databuilder.models.usage.usage import Usage + + +class ColumnReader(Usage): + """ + Represent user's read action on a table - and eventually on a column. + """ + + def __init__(self, + database: str, + cluster: str, + schema: str, + table: str, + column: str, # not used: per-column usage not yet implemented + user_email: str, + read_count: int = 1 + ) -> None: + + Usage.__init__( + self, + start_label=TableMetadata.TABLE_NODE_LABEL, + start_key=TableMetadata.TABLE_KEY_FORMAT.format( + db=database, + cluster=cluster, + schema=schema, + tbl=table), + user_email=user_email, + read_count=read_count, + ) + + +class TableColumnUsage(GraphSerializable, TableSerializable, AtlasSerializable): + """ + Represents an iterable of read actions. 
+ """ + + def __init__(self, col_readers: Iterable[ColumnReader]) -> None: + self.col_readers = col_readers + + self._node_iterator = self._create_node_iterator() + self._rel_iter = self._create_rel_iterator() + self._record_iter = self._create_record_iterator() + self._atlas_entity_iterator = self._create_next_atlas_entity() + self._atlas_relation_iterator = self._create_atlas_relation_iterator() + + def create_next_node(self) -> Union[GraphNode, None]: + try: + return next(self._node_iterator) + except StopIteration: + return None + + def _create_node_iterator(self) -> Iterator[GraphNode]: + for usage in self.col_readers: + node = usage.create_next_node() + while node is not None: + yield node + node = usage.create_next_node() + + def create_next_relation(self) -> Union[GraphRelationship, None]: + try: + return next(self._rel_iter) + except StopIteration: + return None + + def _create_rel_iterator(self) -> Iterator[GraphRelationship]: + for usage in self.col_readers: + rel = usage.create_next_relation() + while rel is not None: + yield rel + rel = usage.create_next_relation() + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iter) + except StopIteration: + return None + + def _create_record_iterator(self) -> Iterator[RDSModel]: + for usage in self.col_readers: + record = usage.create_next_record() + while record is not None: + yield record + record = usage.create_next_record() + + def _create_next_atlas_entity(self) -> Iterator[Optional[AtlasEntity]]: + for usage in self.col_readers: + yield usage.create_next_atlas_entity() + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + try: + return next(self._atlas_entity_iterator) # type: ignore + except StopIteration: + return None + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + try: + return next(self._atlas_relation_iterator) # type: ignore + except StopIteration: + return None + + def _create_atlas_relation_iterator(self) -> Iterator[Optional[AtlasRelationship]]: + for usage in self.col_readers: + yield usage.create_next_atlas_relation() + + def __repr__(self) -> str: + return f'TableColumnUsage(col_readers={self.col_readers!r})' diff --git a/databuilder/databuilder/models/table_elasticsearch_document.py b/databuilder/databuilder/models/table_elasticsearch_document.py new file mode 100644 index 0000000000..1f35335764 --- /dev/null +++ b/databuilder/databuilder/models/table_elasticsearch_document.py @@ -0,0 +1,49 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional + +from databuilder.models.elasticsearch_document import ElasticsearchDocument + + +class TableESDocument(ElasticsearchDocument): + """ + Schema for the Search index document + """ + + def __init__(self, + database: str, + cluster: str, + schema: str, + name: str, + key: str, + description: str, + last_updated_timestamp: Optional[int], + column_names: List[str], + column_descriptions: List[str], + total_usage: int, + unique_usage: int, + tags: List[str], + badges: Optional[List[str]] = None, + display_name: Optional[str] = None, + schema_description: Optional[str] = None, + programmatic_descriptions: List[str] = [], + ) -> None: + self.database = database + self.cluster = cluster + self.schema = schema + self.name = name + self.display_name = display_name if display_name else f'{schema}.{name}' + self.key = key + self.description = description + # todo: use last_updated_timestamp to match the record in metadata + self.last_updated_timestamp = int(last_updated_timestamp) if last_updated_timestamp else None + self.column_names = column_names + self.column_descriptions = column_descriptions + self.total_usage = total_usage + self.unique_usage = unique_usage + # todo: will include tag_type once we have better understanding from UI flow. + self.tags = tags + self.badges = badges + self.schema_description = schema_description + self.programmatic_descriptions = programmatic_descriptions diff --git a/databuilder/databuilder/models/table_last_updated.py b/databuilder/databuilder/models/table_last_updated.py new file mode 100644 index 0000000000..0d0b582da5 --- /dev/null +++ b/databuilder/databuilder/models/table_last_updated.py @@ -0,0 +1,137 @@ +# Copyright Contributors to the Amundsen project. 
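A sketch of building one search-index document with the model above, assuming the databuilder package is importable; every value is illustrative, and the object would normally be handed to an Elasticsearch publisher rather than printed:

from databuilder.models.table_elasticsearch_document import TableESDocument

doc = TableESDocument(
    database='hive',
    cluster='gold',
    schema='test_schema',
    name='test_table',
    key='hive://gold.test_schema/test_table',
    description='an illustrative table',
    last_updated_timestamp=1624000000,
    column_names=['id', 'email'],
    column_descriptions=['primary key', 'user email'],
    total_usage=100,
    unique_usage=10,
    tags=['core'],
    badges=['beta'],
)

# display_name falls back to '<schema>.<name>' when it is not supplied
print(doc.display_name)  # test_schema.test_table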
+# SPDX-License-Identifier: Apache-2.0 + +from typing import Iterator, Union + +from amundsen_rds.models import RDSModel +from amundsen_rds.models.table import TableTimestamp as RDSTableTimestamp + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.table_metadata import TableMetadata +from databuilder.models.table_serializable import TableSerializable +from databuilder.models.timestamp import timestamp_constants + + +class TableLastUpdated(GraphSerializable, TableSerializable, AtlasSerializable): + # constants + LAST_UPDATED_NODE_LABEL = timestamp_constants.NODE_LABEL + LAST_UPDATED_KEY_FORMAT = '{db}://{cluster}.{schema}/{tbl}/timestamp' + TIMESTAMP_PROPERTY = timestamp_constants.DEPRECATED_TIMESTAMP_PROPERTY + TIMESTAMP_NAME_PROPERTY = timestamp_constants.TIMESTAMP_NAME_PROPERTY + + TABLE_LASTUPDATED_RELATION_TYPE = timestamp_constants.LASTUPDATED_RELATION_TYPE + LASTUPDATED_TABLE_RELATION_TYPE = timestamp_constants.LASTUPDATED_REVERSE_RELATION_TYPE + + def __init__(self, + table_name: str, + last_updated_time_epoch: int, + schema: str, + db: str = 'hive', + cluster: str = 'gold' + ) -> None: + self.table_name = table_name + self.last_updated_time = int(last_updated_time_epoch) + self.schema = schema + self.db = db + self.cluster = cluster + + self._node_iter = self._create_node_iterator() + self._relation_iter = self._create_relation_iterator() + self._record_iter = self._create_record_iterator() + + def __repr__(self) -> str: + return f"TableLastUpdated(table_name={self.table_name!r}, last_updated_time={self.last_updated_time!r}, " \ + f"schema={self.schema!r}, db={self.db!r}, cluster={self.cluster!r})" + + def create_next_node(self) -> Union[GraphNode, None]: + # creates new node + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Union[GraphRelationship, None]: + try: + return next(self._relation_iter) + except StopIteration: + return None + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iter) + except StopIteration: + return None + + def get_table_model_key(self) -> str: + # returns formatted string for table name + return TableMetadata.TABLE_KEY_FORMAT.format(db=self.db, + cluster=self.cluster, + schema=self.schema, + tbl=self.table_name) + + def get_last_updated_model_key(self) -> str: + # returns formatted string for last updated name + return TableLastUpdated.LAST_UPDATED_KEY_FORMAT.format(db=self.db, + cluster=self.cluster, + schema=self.schema, + tbl=self.table_name) + + def _create_node_iterator(self) -> Iterator[GraphNode]: + """ + Create a last_updated node + :return: + """ + node = GraphNode( + key=self.get_last_updated_model_key(), + label=TableLastUpdated.LAST_UPDATED_NODE_LABEL, + attributes={ + TableLastUpdated.TIMESTAMP_PROPERTY: self.last_updated_time, + timestamp_constants.TIMESTAMP_PROPERTY: self.last_updated_time, + TableLastUpdated.TIMESTAMP_NAME_PROPERTY: timestamp_constants.TimestampName.last_updated_timestamp.name + } + ) + yield node + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + """ + Create relations mapping last updated node with table node + :return: + """ + relationship 
= GraphRelationship( + start_label=TableMetadata.TABLE_NODE_LABEL, + start_key=self.get_table_model_key(), + end_label=TableLastUpdated.LAST_UPDATED_NODE_LABEL, + end_key=self.get_last_updated_model_key(), + type=TableLastUpdated.TABLE_LASTUPDATED_RELATION_TYPE, + reverse_type=TableLastUpdated.LASTUPDATED_TABLE_RELATION_TYPE, + attributes={} + ) + yield relationship + + def _create_record_iterator(self) -> Iterator[RDSModel]: + """ + Create a table timestamp record + :return: + """ + record = RDSTableTimestamp( + rk=self.get_last_updated_model_key(), + last_updated_timestamp=self.last_updated_time, + timestamp=self.last_updated_time, + name=timestamp_constants.TimestampName.last_updated_timestamp.name, + table_rk=self.get_table_model_key() + ) + yield record + + # Atlas automatically updates `updateTime` of an entity if it's changed (along with storing audit info what changed) + # so we don't really need to implement those methods. The reason they exist at all is so loader class doesn't fail + # if extractor extracts this info. + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + pass + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + pass diff --git a/databuilder/databuilder/models/table_lineage.py b/databuilder/databuilder/models/table_lineage.py new file mode 100644 index 0000000000..3bd100c172 --- /dev/null +++ b/databuilder/databuilder/models/table_lineage.py @@ -0,0 +1,243 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from abc import abstractmethod +from typing import ( + Iterator, List, Optional, Union, +) + +from amundsen_common.utils.atlas import AtlasCommonParams, AtlasTableTypes +from amundsen_rds.models import RDSModel +from amundsen_rds.models.column import ColumnLineage as RDSColumnLineage +from amundsen_rds.models.table import TableLineage as RDSTableLineage + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata +from databuilder.models.table_serializable import TableSerializable +from databuilder.serializers.atlas_serializer import get_entity_attrs +from databuilder.utils.atlas import AtlasRelationshipTypes, AtlasSerializedEntityOperation + + +class BaseLineage(GraphSerializable, AtlasSerializable, TableSerializable): + """ + Generic Lineage Interface + """ + LABEL = 'Lineage' + ORIGIN_DEPENDENCY_RELATION_TYPE = 'HAS_DOWNSTREAM' + DEPENDENCY_ORIGIN_RELATION_TYPE = 'HAS_UPSTREAM' + + def __init__(self) -> None: + self._node_iter = self._create_node_iterator() + self._relation_iter = self._create_rel_iterator() + self._atlas_entity_iterator = self._create_next_atlas_entity() + self._atlas_relation_iterator = self._create_next_atlas_relation() + self._record_iter = self._create_record_iterator() + + def create_next_node(self) -> Union[GraphNode, None]: + # return the string representation of the data + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Union[GraphRelationship, None]: + try: + return next(self._relation_iter) + except StopIteration: + return None + + def _create_node_iterator(self) -> Iterator[GraphNode]: + 
""" + It won't create any node for this model + :return: + """ + return + yield + + @abstractmethod + def _create_rel_iterator(self) -> Iterator[GraphRelationship]: + pass + + def _create_next_atlas_entity(self) -> Iterator[AtlasEntity]: + attrs_mapping = [ + (AtlasCommonParams.qualified_name, self._get_atlas_process_key()), + ('name', self._get_atlas_process_key()) + ] + + entity_attrs = get_entity_attrs(attrs_mapping) + + entity = AtlasEntity( + typeName=AtlasTableTypes.process, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=entity_attrs, + relationships=None + ) + + yield entity + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + try: + return next(self._atlas_entity_iterator) + except StopIteration: + return None + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + try: + return next(self._atlas_relation_iterator) + except StopIteration: + return None + + def _create_next_atlas_relation(self) -> Iterator[AtlasRelationship]: + upstream = AtlasRelationship( + relationshipType=AtlasRelationshipTypes.lineage_upstream, + entityType1=AtlasTableTypes.process, + entityQualifiedName1=self._get_atlas_process_key(), + entityType2=self._get_atlas_entity_type(), + entityQualifiedName2=self._get_atlas_process_key(), + attributes={} + ) + + yield upstream + + for downstream_key in self.downstream_deps: # type: ignore + downstream = AtlasRelationship( + relationshipType=AtlasRelationshipTypes.lineage_downstream, + entityType1=AtlasTableTypes.process, + entityQualifiedName1=self._get_atlas_process_key(), + entityType2=self._get_atlas_entity_type(), + entityQualifiedName2=downstream_key, + attributes={} + ) + + yield downstream + + @abstractmethod + def _get_atlas_process_key(self) -> str: + pass + + @abstractmethod + def _get_atlas_entity_type(self) -> str: + pass + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iter) + except StopIteration: + return None + + @abstractmethod + def _create_record_iterator(self) -> Iterator[RDSModel]: + pass + + +class TableLineage(BaseLineage): + """ + Table Lineage Model. It won't create nodes but create upstream/downstream rels. + """ + + def __init__(self, + table_key: str, + downstream_deps: Optional[List] = None, # List of table keys + ) -> None: + self.table_key = table_key + # a list of downstream dependencies, each of which will follow + # the same key + self.downstream_deps = downstream_deps or [] + super().__init__() + + def _create_rel_iterator(self) -> Iterator[GraphRelationship]: + """ + Create relations between source table and all the downstream tables + :return: + """ + for downstream_key in self.downstream_deps: + relationship = GraphRelationship( + start_key=self.table_key, + start_label=TableMetadata.TABLE_NODE_LABEL, + end_label=TableMetadata.TABLE_NODE_LABEL, + end_key=downstream_key, + type=TableLineage.ORIGIN_DEPENDENCY_RELATION_TYPE, + reverse_type=TableLineage.DEPENDENCY_ORIGIN_RELATION_TYPE, + attributes={} + ) + yield relationship + + def _get_atlas_process_key(self) -> str: + return self.table_key + + def _get_atlas_entity_type(self) -> str: + return AtlasTableTypes.table + + def _create_record_iterator(self) -> Iterator[RDSModel]: + """ + Create lineage records for source table and its all downstream tables. 
+ :return: + """ + for downstream_key in self.downstream_deps: + record = RDSTableLineage( + table_source_rk=self.table_key, + table_target_rk=downstream_key + ) + yield record + + def __repr__(self) -> str: + return f'TableLineage({self.table_key!r})' + + +class ColumnLineage(BaseLineage): + """ + Column Lineage Model. It won't create nodes but create upstream/downstream rels. + """ + + def __init__(self, + column_key: str, + downstream_deps: Optional[List] = None, # List of column keys + ) -> None: + self.column_key = column_key + # a list of downstream dependencies, each of which will follow + # the same key + self.downstream_deps = downstream_deps or [] + super().__init__() + + def _create_rel_iterator(self) -> Iterator[GraphRelationship]: + """ + Create relations between source column and all the downstream columns + :return: + """ + for downstream_key in self.downstream_deps: + relationship = GraphRelationship( + start_key=self.column_key, + start_label=ColumnMetadata.COLUMN_NODE_LABEL, + end_label=ColumnMetadata.COLUMN_NODE_LABEL, + end_key=downstream_key, + type=ColumnLineage.ORIGIN_DEPENDENCY_RELATION_TYPE, + reverse_type=ColumnLineage.DEPENDENCY_ORIGIN_RELATION_TYPE, + attributes={} + ) + yield relationship + + def _get_atlas_process_key(self) -> str: + return self.column_key + + def _get_atlas_entity_type(self) -> str: + return AtlasTableTypes.column + + def _create_record_iterator(self) -> Iterator[RDSModel]: + """ + Create lineage records for source column and its all downstream columns. + :return: + """ + for downstream_key in self.downstream_deps: + record = RDSColumnLineage( + column_source_rk=self.column_key, + column_target_rk=downstream_key + ) + yield record + + def __repr__(self) -> str: + return f'ColumnLineage({self.column_key!r})' diff --git a/databuilder/databuilder/models/table_metadata.py b/databuilder/databuilder/models/table_metadata.py new file mode 100644 index 0000000000..85615c7baa --- /dev/null +++ b/databuilder/databuilder/models/table_metadata.py @@ -0,0 +1,835 @@ +# Copyright Contributors to the Amundsen project. 
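A sketch of declaring table-level lineage with the model above, assuming the databuilder package is importable; the table keys are illustrative. The model emits no nodes of its own, only HAS_DOWNSTREAM / HAS_UPSTREAM relationships:

from databuilder.models.table_lineage import TableLineage

lineage = TableLineage(
    table_key='hive://gold.test_schema/source_table',
    downstream_deps=[
        'hive://gold.test_schema/derived_table_a',
        'hive://gold.test_schema/derived_table_b',
    ],
)

# one relationship per downstream dependency
rel = lineage.create_next_relation()
while rel is not None:
    print(rel.start_key, '->', rel.end_key)
    rel = lineage.create_next_relation()

ColumnLineage is used the same way, with column keys in place of table keys.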
+# SPDX-License-Identifier: Apache-2.0 + +import copy +from typing import ( + TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Union, +) + +from amundsen_common.utils.atlas import ( + AtlasCommonParams, AtlasCommonTypes, AtlasTableTypes, +) +from amundsen_rds.models import RDSModel +from amundsen_rds.models.cluster import Cluster as RDSCluster +from amundsen_rds.models.column import ( + ColumnBadge as RDSColumnBadge, ColumnDescription as RDSColumnDescription, TableColumn as RDSTableColumn, +) +from amundsen_rds.models.database import Database as RDSDatabase +from amundsen_rds.models.schema import Schema as RDSSchema +from amundsen_rds.models.table import ( + Table as RDSTable, TableDescription as RDSTableDescription, + TableProgrammaticDescription as RDSTableProgrammaticDescription, TableTag as RDSTableTag, +) +from amundsen_rds.models.tag import Tag as RDSTag + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.badge import Badge, BadgeMetadata +from databuilder.models.cluster import cluster_constants +from databuilder.models.description_metadata import ( # noqa: F401 + DESCRIPTION_NODE_LABEL, DESCRIPTION_NODE_LABEL_VAL, DescriptionMetadata, +) +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.schema import schema_constant +from databuilder.models.table_serializable import TableSerializable + +if TYPE_CHECKING: + from databuilder.models.type_metadata import TypeMetadata + +from databuilder.serializers.atlas_serializer import ( + add_entity_relationship, get_entity_attrs, get_entity_relationships, +) +from databuilder.utils.atlas import AtlasRelationshipTypes, AtlasSerializedEntityOperation + + +def _format_as_list(tags: Union[List, str, None]) -> List: + if tags is None: + tags = [] + if isinstance(tags, str): + tags = list(filter(None, tags.split(','))) + if isinstance(tags, list): + tags = [tag.lower().strip() for tag in tags] + return tags + + +class TagMetadata(GraphSerializable, TableSerializable, AtlasSerializable): + TAG_NODE_LABEL = 'Tag' + TAG_KEY_FORMAT = '{tag}' + TAG_TYPE = 'tag_type' + DEFAULT_TYPE = 'default' + BADGE_TYPE = 'badge' + DASHBOARD_TYPE = 'dashboard' + METRIC_TYPE = 'metric' + + TAG_ENTITY_RELATION_TYPE = 'TAG' + ENTITY_TAG_RELATION_TYPE = 'TAGGED_BY' + + def __init__(self, + name: str, + tag_type: str = 'default', + ): + self._name = name + self._tag_type = tag_type + self._nodes = self._create_node_iterator() + self._relations = self._create_relation_iterator() + self._records = self._create_record_iterator() + self._atlas_entity_iterator = self._create_next_atlas_entity() + + @staticmethod + def get_tag_key(name: str) -> str: + if not name: + return '' + return TagMetadata.TAG_KEY_FORMAT.format(tag=name) + + def get_node(self) -> GraphNode: + node = GraphNode( + key=TagMetadata.get_tag_key(self._name), + label=TagMetadata.TAG_NODE_LABEL, + attributes={ + TagMetadata.TAG_TYPE: self._tag_type + } + ) + return node + + def get_record(self) -> RDSModel: + record = RDSTag( + rk=TagMetadata.get_tag_key(self._name), + tag_type=self._tag_type + ) + return record + + def create_next_node(self) -> Optional[GraphNode]: + # return the string representation of the data + try: + return next(self._nodes) + except StopIteration: 
+ return None + + def create_next_relation(self) -> Optional[GraphRelationship]: + # We don't emit any relations for Tag ingestion + try: + return next(self._relations) + except StopIteration: + return None + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._records) + except StopIteration: + return None + + def _create_node_iterator(self) -> Iterator[GraphNode]: + node = self.get_node() + yield node + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + return + yield + + def _create_record_iterator(self) -> Iterator[RDSModel]: + record = self.get_record() + yield record + + def _create_atlas_glossary_entity(self) -> AtlasEntity: + attrs_mapping = [ + (AtlasCommonParams.qualified_name, self._name), + ('glossary', self._tag_type), + ('term', self._name) + ] + + entity_attrs = get_entity_attrs(attrs_mapping) + + entity = AtlasEntity( + typeName=AtlasCommonTypes.tag, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=entity_attrs, + relationships=None + ) + + return entity + + def create_atlas_tag_relation(self, table_key: str) -> AtlasRelationship: + table_relationship = AtlasRelationship( + relationshipType=AtlasRelationshipTypes.tag, + entityType1=AtlasCommonTypes.data_set, + entityQualifiedName1=table_key, + entityType2=AtlasRelationshipTypes.tag, + entityQualifiedName2=f'glossary={self._tag_type},term={self._name}', + attributes={} + ) + + return table_relationship + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + pass + + def _create_next_atlas_entity(self) -> Iterator[AtlasEntity]: + yield self._create_atlas_glossary_entity() + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + try: + return next(self._atlas_entity_iterator) + except StopIteration: + return None + + +class ColumnMetadata: + COLUMN_NODE_LABEL = 'Column' + COLUMN_KEY_FORMAT = '{db}://{cluster}.{schema}/{tbl}/{col}' + COLUMN_NAME = 'name' + COLUMN_TYPE = 'col_type' + COLUMN_ORDER = 'sort_order' + COLUMN_DESCRIPTION = 'description' + COLUMN_DESCRIPTION_FORMAT = '{db}://{cluster}.{schema}/{tbl}/{col}/{description_id}' + + def __init__(self, + name: str, + description: Union[str, None], + col_type: str, + sort_order: int, + badges: Union[List[str], None] = None, + ) -> None: + """ + TODO: Add stats + :param name: + :param description: + :param col_type: + :param sort_order: + :param badges: Optional. Column level badges + """ + self.name = name + self.description = DescriptionMetadata.create_description_metadata(source=None, + text=description) + self.type = col_type + self.sort_order = sort_order + formatted_badges = _format_as_list(badges) + self.badges = [Badge(badge, 'column') for badge in formatted_badges] + + # The following fields are populated by the ComplexTypeTransformer + self._column_key: Optional[str] = None + self._type_metadata: Optional[TypeMetadata] = None + + def __repr__(self) -> str: + return f'ColumnMetadata({self.name!r}, {self.description!r}, {self.type!r}, ' \ + f'{self.sort_order!r}, {self.badges!r})' + + def get_column_key(self) -> Optional[str]: + return self._column_key + + def set_column_key(self, col_key: str) -> None: + self._column_key = col_key + + def get_type_metadata(self) -> Optional['TypeMetadata']: + return self._type_metadata + + def set_type_metadata(self, type_metadata: 'TypeMetadata') -> None: + self._type_metadata = type_metadata + + +class TableMetadata(GraphSerializable, TableSerializable, AtlasSerializable): + """ + Table metadata that contains columns. 
It implements Neo4jCsvSerializable so that it can be serialized to produce + Table, Column and relation of those along with relationship with table and schema. Additionally, it will create + Database, Cluster, and Schema with relastionships between those. + These are being created here as it does not make much sense to have different extraction to produce this. As + database, cluster, schema would be very repititive with low cardinality, it will perform de-dupe so that publisher + won't need to publish same nodes, relationships. + + This class can be used for both table and view metadata. If it is a View, is_view=True should be passed in. + """ + TABLE_NODE_LABEL = 'Table' + TABLE_KEY_FORMAT = '{db}://{cluster}.{schema}/{tbl}' + TABLE_NAME = 'name' + IS_VIEW = 'is_view' + + TABLE_DESCRIPTION_FORMAT = '{db}://{cluster}.{schema}/{tbl}/{description_id}' + + DATABASE_NODE_LABEL = 'Database' + DATABASE_KEY_FORMAT = 'database://{db}' + DATABASE_CLUSTER_RELATION_TYPE = cluster_constants.CLUSTER_RELATION_TYPE + CLUSTER_DATABASE_RELATION_TYPE = cluster_constants.CLUSTER_REVERSE_RELATION_TYPE + + CLUSTER_NODE_LABEL = cluster_constants.CLUSTER_NODE_LABEL + CLUSTER_KEY_FORMAT = '{db}://{cluster}' + CLUSTER_SCHEMA_RELATION_TYPE = schema_constant.SCHEMA_RELATION_TYPE + SCHEMA_CLUSTER_RELATION_TYPE = schema_constant.SCHEMA_REVERSE_RELATION_TYPE + + SCHEMA_NODE_LABEL = schema_constant.SCHEMA_NODE_LABEL + SCHEMA_KEY_FORMAT = schema_constant.DATABASE_SCHEMA_KEY_FORMAT + SCHEMA_TABLE_RELATION_TYPE = 'TABLE' + TABLE_SCHEMA_RELATION_TYPE = 'TABLE_OF' + + TABLE_COL_RELATION_TYPE = 'COLUMN' + COL_TABLE_RELATION_TYPE = 'COLUMN_OF' + + TABLE_TAG_RELATION_TYPE = TagMetadata.ENTITY_TAG_RELATION_TYPE + TAG_TABLE_RELATION_TYPE = TagMetadata.TAG_ENTITY_RELATION_TYPE + + # Only for deduping database, cluster, and schema (table and column will be always processed) + serialized_nodes_keys: Set[Any] = set() + serialized_rels_keys: Set[Any] = set() + serialized_records_keys: Set[Any] = set() + + def __init__(self, + database: str, + cluster: str, + schema: str, + name: str, + description: Union[str, None], + columns: Optional[Iterable[ColumnMetadata]] = None, + is_view: bool = False, + tags: Union[List, str, None] = None, + description_source: Union[str, None] = None, + **kwargs: Any + ) -> None: + """ + :param database: + :param cluster: + :param schema: + :param name: + :param description: + :param columns: + :param is_view: Indicate whether the table is a view or not + :param tags: + :param description_source: Optional. Where the description is coming from. Used to compose unique id. + :param kwargs: Put additional attributes to the table model if there is any. 
+ """ + self.database = database + self.cluster = cluster + self.schema = schema + self.name = name + self.description = DescriptionMetadata.create_description_metadata(text=description, source=description_source) + self.columns = columns if columns else [] + self.is_view = is_view + self.attrs: Optional[Dict[str, Any]] = None + + self.tags = _format_as_list(tags) + + if kwargs: + self.attrs = copy.deepcopy(kwargs) + + self._node_iterator = self._create_next_node() + self._relation_iterator = self._create_next_relation() + self._record_iterator = self._create_record_iterator() + self._atlas_entity_iterator = self._create_next_atlas_entity() + self._atlas_relation_iterator = self._create_atlas_relation_iterator() + + def __repr__(self) -> str: + return f'TableMetadata({self.database!r}, {self.cluster!r}, {self.schema!r}, {self.name!r} ' \ + f'{self.description!r}, {self.columns!r}, {self.is_view!r}, {self.tags!r})' + + def _get_table_key(self) -> str: + return TableMetadata.TABLE_KEY_FORMAT.format(db=self.database, + cluster=self.cluster, + schema=self.schema, + tbl=self.name) + + def _get_table_description_key(self, description: DescriptionMetadata) -> str: + return TableMetadata.TABLE_DESCRIPTION_FORMAT.format(db=self.database, + cluster=self.cluster, + schema=self.schema, + tbl=self.name, + description_id=description.get_description_id()) + + def _get_database_key(self) -> str: + return TableMetadata.DATABASE_KEY_FORMAT.format(db=self.database) + + def _get_cluster_key(self) -> str: + return TableMetadata.CLUSTER_KEY_FORMAT.format(db=self.database, + cluster=self.cluster) + + def _get_schema_key(self) -> str: + return TableMetadata.SCHEMA_KEY_FORMAT.format(db=self.database, + cluster=self.cluster, + schema=self.schema) + + def _get_col_key(self, col: ColumnMetadata) -> str: + return ColumnMetadata.COLUMN_KEY_FORMAT.format(db=self.database, + cluster=self.cluster, + schema=self.schema, + tbl=self.name, + col=col.name, + badges=col.badges) + + def _get_col_description_key(self, + col: ColumnMetadata, + description: DescriptionMetadata) -> str: + return ColumnMetadata.COLUMN_DESCRIPTION_FORMAT.format(db=self.database, + cluster=self.cluster, + schema=self.schema, + tbl=self.name, + col=col.name, + description_id=description.get_description_id()) + + @staticmethod + def format_tags(tags: Union[List, str, None]) -> List: + return _format_as_list(tags) + + def create_next_node(self) -> Union[GraphNode, None]: + try: + return next(self._node_iterator) + except StopIteration: + return None + + def _create_next_node(self) -> Iterator[GraphNode]: + yield self._create_table_node() + + if self.description: + node_key = self._get_table_description_key(self.description) + yield self.description.get_node(node_key) + + # Create the table tag nodes + if self.tags: + for tag in self.tags: + tag_node = TagMetadata(tag).get_node() + yield tag_node + + for col in self.columns: + yield from self._create_column_nodes(col) + + # Database, cluster, schema + others = [ + GraphNode( + key=self._get_database_key(), + label=TableMetadata.DATABASE_NODE_LABEL, + attributes={ + 'name': self.database + } + ), + GraphNode( + key=self._get_cluster_key(), + label=TableMetadata.CLUSTER_NODE_LABEL, + attributes={ + 'name': self.cluster + } + ), + GraphNode( + key=self._get_schema_key(), + label=TableMetadata.SCHEMA_NODE_LABEL, + attributes={ + 'name': self.schema + } + ) + ] + + for node_tuple in others: + if node_tuple.key not in TableMetadata.serialized_nodes_keys: + TableMetadata.serialized_nodes_keys.add(node_tuple.key) 
+ yield node_tuple + + def _create_table_node(self) -> GraphNode: + table_attributes = { + TableMetadata.TABLE_NAME: self.name, + TableMetadata.IS_VIEW: self.is_view + } + if self.attrs: + for k, v in self.attrs.items(): + if k not in table_attributes: + table_attributes[k] = v + + return GraphNode( + key=self._get_table_key(), + label=TableMetadata.TABLE_NODE_LABEL, + attributes=table_attributes + ) + + def _create_column_nodes(self, col: ColumnMetadata) -> Iterator[GraphNode]: + column_node = GraphNode( + key=self._get_col_key(col), + label=ColumnMetadata.COLUMN_NODE_LABEL, + attributes={ + ColumnMetadata.COLUMN_NAME: col.name, + ColumnMetadata.COLUMN_TYPE: col.type, + ColumnMetadata.COLUMN_ORDER: col.sort_order + } + ) + yield column_node + + if col.description: + node_key = self._get_col_description_key(col, col.description) + yield col.description.get_node(node_key) + + if col.badges: + col_badge_metadata = BadgeMetadata( + start_label=ColumnMetadata.COLUMN_NODE_LABEL, + start_key=self._get_col_key(col), + badges=col.badges) + badge_nodes = col_badge_metadata.get_badge_nodes() + for node in badge_nodes: + yield node + + type_metadata = col.get_type_metadata() + if type_metadata: + yield from type_metadata.create_node_iterator() + + def create_next_relation(self) -> Union[GraphRelationship, None]: + try: + return next(self._relation_iterator) + except StopIteration: + return None + + def _create_next_relation(self) -> Iterator[GraphRelationship]: + schema_table_relationship = GraphRelationship( + start_key=self._get_schema_key(), + start_label=TableMetadata.SCHEMA_NODE_LABEL, + end_key=self._get_table_key(), + end_label=TableMetadata.TABLE_NODE_LABEL, + type=TableMetadata.SCHEMA_TABLE_RELATION_TYPE, + reverse_type=TableMetadata.TABLE_SCHEMA_RELATION_TYPE, + attributes={} + ) + yield schema_table_relationship + + if self.description: + yield self.description.get_relation(TableMetadata.TABLE_NODE_LABEL, + self._get_table_key(), + self._get_table_description_key(self.description)) + + if self.tags: + for tag in self.tags: + tag_relationship = GraphRelationship( + start_label=TableMetadata.TABLE_NODE_LABEL, + start_key=self._get_table_key(), + end_label=TagMetadata.TAG_NODE_LABEL, + end_key=TagMetadata.get_tag_key(tag), + type=TableMetadata.TABLE_TAG_RELATION_TYPE, + reverse_type=TableMetadata.TAG_TABLE_RELATION_TYPE, + attributes={} + ) + yield tag_relationship + + for col in self.columns: + yield from self._create_column_relations(col) + + others = [ + GraphRelationship( + start_label=TableMetadata.DATABASE_NODE_LABEL, + end_label=TableMetadata.CLUSTER_NODE_LABEL, + start_key=self._get_database_key(), + end_key=self._get_cluster_key(), + type=TableMetadata.DATABASE_CLUSTER_RELATION_TYPE, + reverse_type=TableMetadata.CLUSTER_DATABASE_RELATION_TYPE, + attributes={} + ), + GraphRelationship( + start_label=TableMetadata.CLUSTER_NODE_LABEL, + end_label=TableMetadata.SCHEMA_NODE_LABEL, + start_key=self._get_cluster_key(), + end_key=self._get_schema_key(), + type=TableMetadata.CLUSTER_SCHEMA_RELATION_TYPE, + reverse_type=TableMetadata.SCHEMA_CLUSTER_RELATION_TYPE, + attributes={} + ) + ] + + for rel_tuple in others: + if (rel_tuple.start_key, rel_tuple.end_key, rel_tuple.type) not in TableMetadata.serialized_rels_keys: + TableMetadata.serialized_rels_keys.add((rel_tuple.start_key, rel_tuple.end_key, rel_tuple.type)) + yield rel_tuple + + def _create_column_relations(self, col: ColumnMetadata) -> Iterator[GraphRelationship]: + column_relationship = GraphRelationship( + 
start_label=TableMetadata.TABLE_NODE_LABEL, + start_key=self._get_table_key(), + end_label=ColumnMetadata.COLUMN_NODE_LABEL, + end_key=self._get_col_key(col), + type=TableMetadata.TABLE_COL_RELATION_TYPE, + reverse_type=TableMetadata.COL_TABLE_RELATION_TYPE, + attributes={} + ) + yield column_relationship + + if col.description: + yield col.description.get_relation( + ColumnMetadata.COLUMN_NODE_LABEL, + self._get_col_key(col), + self._get_col_description_key(col, col.description) + ) + + if col.badges: + badge_metadata = BadgeMetadata(start_label=ColumnMetadata.COLUMN_NODE_LABEL, + start_key=self._get_col_key(col), + badges=col.badges) + badge_relations = badge_metadata.get_badge_relations() + for relation in badge_relations: + yield relation + + type_metadata = col.get_type_metadata() + if type_metadata: + yield from type_metadata.create_relation_iterator() + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iterator) + except StopIteration: + return None + + def _create_record_iterator(self) -> Iterator[RDSModel]: + # Database, Cluster, Schema + others: List[RDSModel] = [ + RDSDatabase( + rk=self._get_database_key(), + name=self.database + ), + RDSCluster( + rk=self._get_cluster_key(), + name=self.cluster, + database_rk=self._get_database_key() + ), + RDSSchema( + rk=self._get_schema_key(), + name=self.schema, + cluster_rk=self._get_cluster_key() + ) + ] + + for record in others: + if record.rk not in TableMetadata.serialized_records_keys: + TableMetadata.serialized_records_keys.add(record.rk) + yield record + + # Table + yield RDSTable( + rk=self._get_table_key(), + name=self.name, + is_view=self.is_view, + schema_rk=self._get_schema_key() + ) + + # Table description + if self.description: + description_record_key = self._get_table_description_key(self.description) + if self.description.label == DescriptionMetadata.DESCRIPTION_NODE_LABEL: + yield RDSTableDescription( + rk=description_record_key, + description_source=self.description.source, + description=self.description.text, + table_rk=self._get_table_key() + ) + else: + yield RDSTableProgrammaticDescription( + rk=description_record_key, + description_source=self.description.source, + description=self.description.text, + table_rk=self._get_table_key() + ) + + # Tag + for tag in self.tags: + tag_record = TagMetadata(tag).get_record() + yield tag_record + + table_tag_record = RDSTableTag( + table_rk=self._get_table_key(), + tag_rk=TagMetadata.get_tag_key(tag) + ) + yield table_tag_record + + # Column + for col in self.columns: + yield RDSTableColumn( + rk=self._get_col_key(col), + name=col.name, + type=col.type, + sort_order=col.sort_order, + table_rk=self._get_table_key() + ) + + if col.description: + description_record_key = self._get_col_description_key(col, col.description) + yield RDSColumnDescription( + rk=description_record_key, + description_source=col.description.source, + description=col.description.text, + column_rk=self._get_col_key(col) + ) + + if col.badges: + badge_metadata = BadgeMetadata( + start_label=ColumnMetadata.COLUMN_NODE_LABEL, + start_key=self._get_col_key(col), + badges=col.badges + ) + + badge_records = badge_metadata.get_badge_records() + for badge_record in badge_records: + yield badge_record + + column_badge_record = RDSColumnBadge( + column_rk=self._get_col_key(col), + badge_rk=badge_record.rk + ) + yield column_badge_record + + def _create_atlas_cluster_entity(self) -> AtlasEntity: + attrs_mapping = [ + (AtlasCommonParams.qualified_name, self._get_cluster_key()), 
+ ('name', self.cluster), + ('displayName', self.cluster) + ] + + entity_attrs = get_entity_attrs(attrs_mapping) + + entity = AtlasEntity( + typeName=AtlasCommonTypes.cluster, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=entity_attrs, + relationships=None + ) + + return entity + + def _create_atlas_database_entity(self) -> AtlasEntity: + attrs_mapping = [ + (AtlasCommonParams.qualified_name, self._get_database_key()), + ('name', self.database), + ('displayName', self.database) + ] + + entity_attrs = get_entity_attrs(attrs_mapping) + + relationship_list = [] # type: ignore + + add_entity_relationship( + relationship_list, + 'cluster', + AtlasCommonTypes.cluster, + self._get_cluster_key() + ) + + entity = AtlasEntity( + typeName=AtlasTableTypes.database, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=entity_attrs, + relationships=get_entity_relationships(relationship_list) + ) + + return entity + + def _create_atlas_schema_entity(self) -> AtlasEntity: + attrs_mapping = [ + (AtlasCommonParams.qualified_name, self._get_schema_key()), + ('name', self.schema), + ('displayName', self.schema) + ] + + entity_attrs = get_entity_attrs(attrs_mapping) + + relationship_list = [] # type: ignore + + add_entity_relationship( + relationship_list, + 'cluster', + AtlasCommonTypes.cluster, + self._get_cluster_key() + ) + + entity = AtlasEntity( + typeName=AtlasTableTypes.schema, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=entity_attrs, + relationships=get_entity_relationships(relationship_list) + ) + + return entity + + def _create_atlas_table_entity(self) -> AtlasEntity: + table_type = 'table' if not self.is_view else 'view' + + attrs_mapping = [ + (AtlasCommonParams.qualified_name, self._get_table_key()), + ('name', self.name), + ('tableType', table_type), + ('description', self.description.text if self.description else ''), + ('displayName', self.name) + ] + + entity_attrs = get_entity_attrs(attrs_mapping) + + relationship_list = [] # type: ignore + + add_entity_relationship( + relationship_list, + 'amundsen_schema', + AtlasTableTypes.schema, + self._get_schema_key() + ) + + entity = AtlasEntity( + typeName=AtlasTableTypes.table, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=entity_attrs, + relationships=get_entity_relationships(relationship_list) + ) + + return entity + + def _create_atlas_column_entity(self, column_metadata: ColumnMetadata) -> AtlasEntity: + qualified_name = column_metadata.COLUMN_KEY_FORMAT.format(db=self.database, + cluster=self.cluster, + schema=self.schema, + tbl=self.name, + col=column_metadata.name) + attrs_mapping = [ + (AtlasCommonParams.qualified_name, qualified_name), + ('name', column_metadata.name or ''), + ('description', column_metadata.description.text if column_metadata.description else ''), + ('type', column_metadata.type), + ('position', column_metadata.sort_order), + ('displayName', column_metadata.name or '') + ] + + entity_attrs = get_entity_attrs(attrs_mapping) + + relationship_list = [] # type: ignore + + add_entity_relationship( + relationship_list, + 'table', + AtlasTableTypes.table, + self._get_table_key() + ) + + entity = AtlasEntity( + typeName=AtlasTableTypes.column, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=entity_attrs, + relationships=get_entity_relationships(relationship_list) + ) + + return entity + + def _create_atlas_relation_iterator(self) -> Iterator[AtlasRelationship]: + for tag in self.tags: + tag_relation = 
TagMetadata(tag).create_atlas_tag_relation(self._get_table_key()) + yield tag_relation + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + try: + return next(self._atlas_relation_iterator) + except StopIteration: + return None + + def _create_next_atlas_entity(self) -> Iterator[AtlasEntity]: + yield self._create_atlas_cluster_entity() + yield self._create_atlas_database_entity() + yield self._create_atlas_schema_entity() + yield self._create_atlas_table_entity() + + for col in self.columns: + yield self._create_atlas_column_entity(col) + + if self.tags: + for tag in self.tags: + tag_entity = TagMetadata(tag).create_next_atlas_entity() + if tag_entity: + yield tag_entity + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + try: + return next(self._atlas_entity_iterator) + except StopIteration: + return None diff --git a/databuilder/databuilder/models/table_owner.py b/databuilder/databuilder/models/table_owner.py new file mode 100644 index 0000000000..cd73b9b867 --- /dev/null +++ b/databuilder/databuilder/models/table_owner.py @@ -0,0 +1,30 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Union + +from databuilder.models.owner import Owner +from databuilder.models.table_metadata import TableMetadata + + +class TableOwner(Owner): + """ + Table owner model. + """ + + def __init__(self, + db_name: str, + schema: str, + table_name: str, + owners: Union[List, str], + cluster: str = 'gold', + ) -> None: + self.start_label = TableMetadata.TABLE_NODE_LABEL + self.start_key = f'{db_name}://{cluster}.{schema}/{table_name}' + + Owner.__init__( + self, + start_label=self.start_label, + start_key=self.start_key, + owner_emails=owners, + ) diff --git a/databuilder/databuilder/models/table_serializable.py b/databuilder/databuilder/models/table_serializable.py new file mode 100644 index 0000000000..dbabf31cd1 --- /dev/null +++ b/databuilder/databuilder/models/table_serializable.py @@ -0,0 +1,37 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import abc +from typing import Union + +from amundsen_rds.models import RDSModel + + +class TableSerializable(object, metaclass=abc.ABCMeta): + """ + A Serializable abstract class asks subclass to implement next record + in rds model instance form so that it can be serialized to CSV file. + + Any model class that needs to be pushed to a relational database should inherit this class. + """ + + def __init__(self) -> None: + pass + + @abc.abstractmethod + def create_next_record(self) -> Union[RDSModel, None]: + """ + Creates rds model instance. + The process that consumes this class takes the output and serializes + the record to the desired relational database. + + :return: a rds model instance or None if no more records to serialize + """ + raise NotImplementedError + + def next_record(self) -> Union[RDSModel, None]: + record = self.create_next_record() + if not record: + return None + + return record diff --git a/databuilder/databuilder/models/table_source.py b/databuilder/databuilder/models/table_source.py new file mode 100644 index 0000000000..723f43cd06 --- /dev/null +++ b/databuilder/databuilder/models/table_source.py @@ -0,0 +1,170 @@ +# Copyright Contributors to the Amundsen project. 
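A sketch of how ColumnMetadata and TableMetadata defined above fit together, assuming the databuilder package and its dependencies are importable; the table, columns and tags are illustrative. Draining the node iterator shows the Table, Column, Tag, Database, Cluster and Schema nodes being produced (the last three are de-duplicated across instances via the class-level serialized_* sets):

from databuilder.models.table_metadata import ColumnMetadata, TableMetadata

columns = [
    ColumnMetadata('id', 'primary key', 'bigint', 0),
    ColumnMetadata('email', 'user email', 'varchar', 1, badges=['pii']),
]

table = TableMetadata(
    database='hive',
    cluster='gold',
    schema='test_schema',
    name='test_table',
    description='an illustrative table',
    columns=columns,
    is_view=False,
    tags='core,reporting',  # a comma-separated string is normalized into a lower-cased list
)

node = table.create_next_node()
while node is not None:
    print(node.label, node.key)
    node = table.create_next_node()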
+# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + Iterator, Optional, Union, +) + +from amundsen_common.utils.atlas import AtlasCommonParams, AtlasTableTypes +from amundsen_rds.models import RDSModel +from amundsen_rds.models.table import TableSource as RDSTableSource + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.table_metadata import TableMetadata +from databuilder.models.table_serializable import TableSerializable +from databuilder.serializers.atlas_serializer import get_entity_attrs +from databuilder.utils.atlas import AtlasRelationshipTypes, AtlasSerializedEntityOperation + + +class TableSource(GraphSerializable, TableSerializable, AtlasSerializable): + """ + Hive table source model. + """ + LABEL = 'Source' + KEY_FORMAT = '{db}://{cluster}.{schema}/{tbl}/_source' + SOURCE_TABLE_RELATION_TYPE = 'SOURCE_OF' + TABLE_SOURCE_RELATION_TYPE = 'SOURCE' + + def __init__(self, + db_name: str, + schema: str, + table_name: str, + cluster: str, + source: str, + source_type: str = 'github', + ) -> None: + self.db = db_name + self.schema = schema + self.table = table_name + + self.cluster = cluster if cluster else 'gold' + # source is the source file location + self.source = source + self.source_type = source_type + self._node_iter = self._create_node_iterator() + self._relation_iter = self._create_relation_iterator() + self._record_iter = self._create_record_iterator() + self._atlas_entity_iterator = self._create_next_atlas_entity() + self._atlas_relation_iterator = self._create_atlas_relation_iterator() + + def create_next_node(self) -> Optional[GraphNode]: + # return the string representation of the data + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Optional[GraphRelationship]: + try: + return next(self._relation_iter) + except StopIteration: + return None + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iter) + except StopIteration: + return None + + def get_source_model_key(self) -> str: + return TableSource.KEY_FORMAT.format(db=self.db, + cluster=self.cluster, + schema=self.schema, + tbl=self.table) + + def get_metadata_model_key(self) -> str: + return f'{self.db}://{self.cluster}.{self.schema}/{self.table}' + + def _create_node_iterator(self) -> Iterator[GraphNode]: + """ + Create a table source node + :return: + """ + node = GraphNode( + key=self.get_source_model_key(), + label=TableSource.LABEL, + attributes={ + 'source': self.source, + 'source_type': self.source_type + } + ) + yield node + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + """ + Create relation map between owner record with original hive table + :return: + """ + relationship = GraphRelationship( + start_label=TableSource.LABEL, + start_key=self.get_source_model_key(), + end_label=TableMetadata.TABLE_NODE_LABEL, + end_key=self.get_metadata_model_key(), + type=TableSource.SOURCE_TABLE_RELATION_TYPE, + reverse_type=TableSource.TABLE_SOURCE_RELATION_TYPE, + attributes={} + ) + yield relationship + + def _create_record_iterator(self) -> Iterator[RDSModel]: + record = RDSTableSource( + rk=self.get_source_model_key(), + 
source=self.source, + source_type=self.source_type, + table_rk=self.get_metadata_model_key() + ) + yield record + + def _create_atlas_source_entity(self) -> AtlasEntity: + attrs_mapping = [ + (AtlasCommonParams.qualified_name, self.get_source_model_key()), + ('name', self.source), + ('source_type', self.source_type), + ('displayName', self.source) + ] + + entity_attrs = get_entity_attrs(attrs_mapping) + + entity = AtlasEntity( + typeName=AtlasTableTypes.source, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=entity_attrs, + relationships=None + ) + + return entity + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + try: + return next(self._atlas_relation_iterator) + except StopIteration: + return None + + def _create_atlas_relation_iterator(self) -> Iterator[AtlasRelationship]: + relationship = AtlasRelationship( + relationshipType=AtlasRelationshipTypes.table_source, + entityType1=AtlasTableTypes.source, + entityQualifiedName1=self.get_source_model_key(), + entityType2=AtlasTableTypes.table, + entityQualifiedName2=self.get_metadata_model_key(), + attributes={} + ) + + yield relationship + + def _create_next_atlas_entity(self) -> Iterator[AtlasEntity]: + yield self._create_atlas_source_entity() + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + try: + return next(self._atlas_entity_iterator) + except StopIteration: + return None + + def __repr__(self) -> str: + return f'TableSource({self.db!r}, {self.cluster!r}, {self.schema!r}, {self.table!r}, {self.source!r})' diff --git a/databuilder/databuilder/models/table_stats.py b/databuilder/databuilder/models/table_stats.py new file mode 100644 index 0000000000..b5f992c33f --- /dev/null +++ b/databuilder/databuilder/models/table_stats.py @@ -0,0 +1,235 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 +from typing import ( + Iterator, Optional, Union, +) + +from amundsen_rds.models import RDSModel +from amundsen_rds.models.column import ColumnStat as RDSColumnStat + +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata +from databuilder.models.table_serializable import TableSerializable + +LABEL = 'Stat' +STAT_RESOURCE_RELATION_TYPE = 'STAT_OF' +RESOURCE_STAT_RELATION_TYPE = 'STAT' + + +class TableStats(GraphSerializable, TableSerializable): + """ + Table stats model. 
+ """ + + KEY_FORMAT = '{db}://{cluster}.{schema}' \ + '/{table}/{stat_name}/' + + def __init__(self, + table_name: str, + stat_name: str, + stat_val: str, + is_metric: bool, + db: str = 'hive', + schema: Optional[str] = None, + cluster: str = 'gold', + start_epoch: Optional[str] = None, + end_epoch: Optional[str] = None + ) -> None: + if schema is None: + self.schema, self.table = table_name.split('.') + else: + self.table = table_name + self.schema = schema + self.db = db + self.start_epoch = start_epoch + self.end_epoch = end_epoch + self.cluster = cluster + self.stat_name = stat_name + self.stat_val = str(stat_val) + # metrics are about the table, stats are about the data in a table + # ex: table usage is a metric + self.is_metric = is_metric + self._node_iter = self._create_node_iterator() + self._relation_iter = self._create_relation_iterator() + + def create_next_node(self) -> Optional[GraphNode]: + # return the string representation of the data + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Optional[GraphRelationship]: + try: + return next(self._relation_iter) + except StopIteration: + return None + + def create_next_record(self) -> Union[RDSModel, None]: + return None + + def get_table_stat_model_key(self) -> str: + return TableStats.KEY_FORMAT.format(db=self.db, + cluster=self.cluster, + schema=self.schema, + table=self.table, + stat_name=self.stat_name, + is_metric=self.is_metric) + + def get_table_key(self) -> str: + # no cluster, schema info from the input + return TableMetadata.TABLE_KEY_FORMAT.format(db=self.db, + cluster=self.cluster, + schema=self.schema, + tbl=self.table) + + def _create_node_iterator(self) -> Iterator[GraphNode]: + """ + Create a table stat node + :return: + """ + node = GraphNode( + key=self.get_table_stat_model_key(), + label=LABEL, + attributes={ + 'stat_val': self.stat_val, + 'stat_type': self.stat_name, + 'start_epoch': self.start_epoch, + 'end_epoch': self.end_epoch, + 'is_metric': self.is_metric, + } + ) + yield node + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + """ + Create relation map between table stat record with original table + :return: + """ + relationship = GraphRelationship( + start_key=self.get_table_stat_model_key(), + start_label=LABEL, + end_key=self.get_table_key(), + end_label=TableMetadata.TABLE_NODE_LABEL, + type=STAT_RESOURCE_RELATION_TYPE, + reverse_type=RESOURCE_STAT_RELATION_TYPE, + attributes={} + ) + yield relationship + + +class TableColumnStats(GraphSerializable, TableSerializable): + """ + Hive column stats model. + Each instance represents one row of hive watermark result. 
+ """ + KEY_FORMAT = '{db}://{cluster}.{schema}' \ + '/{table}/{col}/{stat_type}/' + + def __init__(self, + table_name: str, + col_name: str, + stat_name: str, + stat_val: str, + start_epoch: str, + end_epoch: str, + db: str = 'hive', + cluster: str = 'gold', + schema: Optional[str] = None + ) -> None: + if schema is None: + self.schema, self.table = table_name.split('.') + else: + self.table = table_name + self.schema = schema + self.db = db + self.col_name = col_name + self.start_epoch = start_epoch + self.end_epoch = end_epoch + self.cluster = cluster + self.stat_type = stat_name + self.stat_val = str(stat_val) + self._node_iter = self._create_node_iterator() + self._relation_iter = self._create_relation_iterator() + self._record_iter = self._create_record_iterator() + + def create_next_node(self) -> Optional[GraphNode]: + # return the string representation of the data + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Optional[GraphRelationship]: + try: + return next(self._relation_iter) + except StopIteration: + return None + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iter) + except StopIteration: + return None + + def get_column_stat_model_key(self) -> str: + return TableColumnStats.KEY_FORMAT.format(db=self.db, + cluster=self.cluster, + schema=self.schema, + table=self.table, + col=self.col_name, + stat_type=self.stat_type) + + def get_col_key(self) -> str: + # no cluster, schema info from the input + return ColumnMetadata.COLUMN_KEY_FORMAT.format(db=self.db, + cluster=self.cluster, + schema=self.schema, + tbl=self.table, + col=self.col_name) + + def _create_node_iterator(self) -> Iterator[GraphNode]: + """ + Create a table stat node + :return: + """ + node = GraphNode( + key=self.get_column_stat_model_key(), + label=LABEL, + attributes={ + 'stat_val': self.stat_val, + 'stat_type': self.stat_type, + 'start_epoch': self.start_epoch, + 'end_epoch': self.end_epoch, + } + ) + yield node + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + """ + Create relation map between table stat record with original hive table + :return: + """ + relationship = GraphRelationship( + start_key=self.get_column_stat_model_key(), + start_label=LABEL, + end_key=self.get_col_key(), + end_label=ColumnMetadata.COLUMN_NODE_LABEL, + type=STAT_RESOURCE_RELATION_TYPE, + reverse_type=RESOURCE_STAT_RELATION_TYPE, + attributes={} + ) + yield relationship + + def _create_record_iterator(self) -> Iterator[RDSModel]: + record = RDSColumnStat( + rk=self.get_column_stat_model_key(), + stat_val=self.stat_val, + stat_type=self.stat_type, + start_epoch=self.start_epoch, + end_epoch=self.end_epoch, + column_rk=self.get_col_key() + ) + yield record diff --git a/databuilder/databuilder/models/timestamp/__init__.py b/databuilder/databuilder/models/timestamp/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/models/timestamp/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/models/timestamp/timestamp_constants.py b/databuilder/databuilder/models/timestamp/timestamp_constants.py new file mode 100644 index 0000000000..723f463184 --- /dev/null +++ b/databuilder/databuilder/models/timestamp/timestamp_constants.py @@ -0,0 +1,19 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from enum import Enum + +NODE_LABEL = 'Timestamp' + +TIMESTAMP_PROPERTY = 'timestamp' +TIMESTAMP_NAME_PROPERTY = 'name' +# This is deprecated property as it's not generic for the Timestamp +DEPRECATED_TIMESTAMP_PROPERTY = 'last_updated_timestamp' + + +LASTUPDATED_RELATION_TYPE = 'LAST_UPDATED_AT' +LASTUPDATED_REVERSE_RELATION_TYPE = 'LAST_UPDATED_TIME_OF' + + +class TimestampName(Enum): + last_updated_timestamp = 1 diff --git a/databuilder/databuilder/models/type_metadata.py b/databuilder/databuilder/models/type_metadata.py new file mode 100644 index 0000000000..e31c16ab54 --- /dev/null +++ b/databuilder/databuilder/models/type_metadata.py @@ -0,0 +1,502 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import abc +import logging +from typing import ( + Any, Dict, Iterator, List, Optional, Union, +) + +from databuilder.models.badge import Badge, BadgeMetadata +from databuilder.models.description_metadata import DescriptionMetadata +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.table_metadata import ColumnMetadata, _format_as_list + +LOGGER = logging.getLogger(__name__) + + +class TypeMetadata(abc.ABC, GraphSerializable): + NODE_LABEL = 'Type_Metadata' + COL_TM_RELATION_TYPE = 'TYPE_METADATA' + TM_COL_RELATION_TYPE = 'TYPE_METADATA_OF' + SUBTYPE_RELATION_TYPE = 'SUBTYPE' + INVERSE_SUBTYPE_RELATION_TYPE = 'SUBTYPE_OF' + KIND = 'kind' + NAME = 'name' + DATA_TYPE = 'data_type' + SORT_ORDER = 'sort_order' + + @abc.abstractmethod + def __init__(self, + name: str, + parent: Union[ColumnMetadata, 'TypeMetadata'], + type_str: str, + sort_order: Optional[int] = None) -> None: + self.name = name + self.parent = parent + self.type_str = type_str + # Sort order among TypeMetadata objects with the same parent + self.sort_order = sort_order + + self._description: Optional[DescriptionMetadata] = None + self._badges: Optional[List[Badge]] = None + + self._node_iter = self.create_node_iterator() + self._relation_iter = self.create_relation_iterator() + + def get_description(self) -> Optional[DescriptionMetadata]: + return self._description + + def set_description(self, description: str) -> None: + if isinstance(self.parent, ColumnMetadata): + LOGGER.warning("""Frontend does not currently support setting descriptions for type metadata + objects with a ColumnMetadata parent, since the top level type metadata does + not have its own row in the column table""") + elif isinstance(self.parent, ArrayTypeMetadata): + LOGGER.warning("""Frontend does not currently support setting descriptions for type metadata + objects with an ArrayTypeMetadata parent, since this level in the nesting + hierarchy is not named and therefore is represented by short row that is not + clickable""") + else: + self._description = DescriptionMetadata.create_description_metadata( + source=None, + text=description + ) + + def get_badges(self) -> Optional[List[Badge]]: + return self._badges + + def set_badges(self, badges: Union[List[str], None] = None) -> None: + if isinstance(self.parent, ColumnMetadata): + LOGGER.warning("""Frontend does not currently support setting badges for type metadata + objects with a ColumnMetadata parent, since the top level type metadata does + not have its own row in the column table""") + elif isinstance(self.parent, ArrayTypeMetadata): + 
LOGGER.warning("""Frontend does not currently support setting badges for type metadata + objects with an ArrayTypeMetadata parent, since this level in the nesting + hierarchy is not named and therefore is represented by short row that is not + clickable""") + else: + formatted_badges = _format_as_list(badges) + self._badges = [Badge(badge, 'type_metadata') for badge in formatted_badges] + + @abc.abstractmethod + def __eq__(self, other: Any) -> bool: + raise NotImplementedError + + @abc.abstractmethod + def is_terminal_type(self) -> bool: + """ + This is used to determine whether any child nodes + should be created for the associated TypeMetadata object. + """ + raise NotImplementedError + + @abc.abstractmethod + def create_node_iterator(self) -> Iterator[GraphNode]: + raise NotImplementedError + + @abc.abstractmethod + def create_relation_iterator(self) -> Iterator[GraphRelationship]: + raise NotImplementedError + + def create_next_node(self) -> Optional[GraphNode]: + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Optional[GraphRelationship]: + try: + return next(self._relation_iter) + except StopIteration: + return None + + def key(self) -> str: + if isinstance(self.parent, ColumnMetadata): + return f"{self.parent_key()}/type/{self.name}" + return f"{self.parent_key()}/{self.name}" + + def description_key(self) -> Optional[str]: + if self._description: + description_id = self._description.get_description_id() + return f"{self.key()}/{description_id}" + return None + + def relation_type(self) -> str: + if isinstance(self.parent, ColumnMetadata): + return TypeMetadata.COL_TM_RELATION_TYPE + return TypeMetadata.SUBTYPE_RELATION_TYPE + + def inverse_relation_type(self) -> str: + if isinstance(self.parent, ColumnMetadata): + return TypeMetadata.TM_COL_RELATION_TYPE + return TypeMetadata.INVERSE_SUBTYPE_RELATION_TYPE + + def parent_key(self) -> str: + if isinstance(self.parent, ColumnMetadata): + column_key = self.parent.get_column_key() + assert column_key is not None, f"Column key must be set for {self.parent.name}" + return column_key + return self.parent.key() + + def parent_label(self) -> str: + if isinstance(self.parent, ColumnMetadata): + return ColumnMetadata.COLUMN_NODE_LABEL + return TypeMetadata.NODE_LABEL + + def __repr__(self) -> str: + return f"TypeMetadata({self.type_str!r})" + + +class ArrayTypeMetadata(TypeMetadata): + kind = 'array' + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super(ArrayTypeMetadata, self).__init__(*args, **kwargs) + self.array_inner_type: Optional[TypeMetadata] = None + + def __eq__(self, other: Any) -> bool: + if isinstance(other, ArrayTypeMetadata): + return (self.name == other.name and + self.type_str == other.type_str and + self.sort_order == other.sort_order and + self._description == other._description and + self._badges == other._badges and + self.array_inner_type == other.array_inner_type and + self.key() == other.key()) + return False + + def is_terminal_type(self) -> bool: + return not self.array_inner_type + + def create_node_iterator(self) -> Iterator[GraphNode]: + node_attributes: Dict[str, Union[str, None, int]] = { + TypeMetadata.KIND: self.kind, + TypeMetadata.NAME: self.name, + TypeMetadata.DATA_TYPE: self.type_str + } + + if isinstance(self.sort_order, int): + node_attributes[TypeMetadata.SORT_ORDER] = self.sort_order + + yield GraphNode( + key=self.key(), + label=TypeMetadata.NODE_LABEL, + attributes=node_attributes + ) + + if self._description: + 
description_key = self.description_key() + assert description_key is not None, f"Could not retrieve description key for {self.name}" + yield self._description.get_node(description_key) + + if self._badges: + badge_metadata = BadgeMetadata(start_label=TypeMetadata.NODE_LABEL, + start_key=self.key(), + badges=self._badges) + badge_nodes = badge_metadata.get_badge_nodes() + for node in badge_nodes: + yield node + + if not self.is_terminal_type(): + assert self.array_inner_type is not None, f"Array inner type must be set for {self.name}" + yield from self.array_inner_type.create_node_iterator() + + def create_relation_iterator(self) -> Iterator[GraphRelationship]: + yield GraphRelationship( + start_label=self.parent_label(), + start_key=self.parent_key(), + end_label=TypeMetadata.NODE_LABEL, + end_key=self.key(), + type=self.relation_type(), + reverse_type=self.inverse_relation_type(), + attributes={} + ) + + if self._description: + description_key = self.description_key() + assert description_key is not None, f"Could not retrieve description key for {self.name}" + yield self._description.get_relation( + TypeMetadata.NODE_LABEL, + self.key(), + description_key + ) + + if self._badges: + badge_metadata = BadgeMetadata(start_label=TypeMetadata.NODE_LABEL, + start_key=self.key(), + badges=self._badges) + badge_relations = badge_metadata.get_badge_relations() + for relation in badge_relations: + yield relation + + if not self.is_terminal_type(): + assert self.array_inner_type is not None, f"Array inner type must be set for {self.name}" + yield from self.array_inner_type.create_relation_iterator() + + +class MapTypeMetadata(TypeMetadata): + kind = 'map' + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super(MapTypeMetadata, self).__init__(*args, **kwargs) + self.map_key_type: Optional[TypeMetadata] = None + self.map_value_type: Optional[TypeMetadata] = None + + def __eq__(self, other: Any) -> bool: + if isinstance(other, MapTypeMetadata): + return (self.name == other.name and + self.map_key_type == other.map_key_type and + self.map_value_type == other.map_value_type and + self.type_str == other.type_str and + self.sort_order == other.sort_order and + self._description == other._description and + self._badges == other._badges and + self.key() == other.key()) + return False + + def is_terminal_type(self) -> bool: + return not self.map_key_type or not self.map_value_type + + def create_node_iterator(self) -> Iterator[GraphNode]: + node_attributes: Dict[str, Union[str, None, int]] = { + TypeMetadata.KIND: self.kind, + TypeMetadata.NAME: self.name, + TypeMetadata.DATA_TYPE: self.type_str + } + + if isinstance(self.sort_order, int): + node_attributes[TypeMetadata.SORT_ORDER] = self.sort_order + + yield GraphNode( + key=self.key(), + label=TypeMetadata.NODE_LABEL, + attributes=node_attributes + ) + + if self._description: + description_key = self.description_key() + assert description_key is not None, f"Could not retrieve description key for {self.name}" + yield self._description.get_node(description_key) + + if self._badges: + badge_metadata = BadgeMetadata(start_label=TypeMetadata.NODE_LABEL, + start_key=self.key(), + badges=self._badges) + badge_nodes = badge_metadata.get_badge_nodes() + for node in badge_nodes: + yield node + + if not self.is_terminal_type(): + assert self.map_key_type is not None, f"Map key type must be set for {self.name}" + assert self.map_value_type is not None, f"Map value type must be set for {self.name}" + yield from self.map_key_type.create_node_iterator() + yield 
from self.map_value_type.create_node_iterator() + + def create_relation_iterator(self) -> Iterator[GraphRelationship]: + yield GraphRelationship( + start_label=self.parent_label(), + start_key=self.parent_key(), + end_label=TypeMetadata.NODE_LABEL, + end_key=self.key(), + type=self.relation_type(), + reverse_type=self.inverse_relation_type(), + attributes={} + ) + + if self._description: + description_key = self.description_key() + assert description_key is not None, f"Could not retrieve description key for {self.name}" + yield self._description.get_relation( + TypeMetadata.NODE_LABEL, + self.key(), + description_key + ) + + if self._badges: + badge_metadata = BadgeMetadata(start_label=TypeMetadata.NODE_LABEL, + start_key=self.key(), + badges=self._badges) + badge_relations = badge_metadata.get_badge_relations() + for relation in badge_relations: + yield relation + + if not self.is_terminal_type(): + assert self.map_key_type is not None, f"Map key type must be set for {self.name}" + assert self.map_value_type is not None, f"Map value type must be set for {self.name}" + yield from self.map_key_type.create_relation_iterator() + yield from self.map_value_type.create_relation_iterator() + + +class ScalarTypeMetadata(TypeMetadata): + """ + ScalarTypeMetadata represents any non complex type that does not + require special handling. It is also used as the default TypeMetadata + class when a type string cannot be parsed. + """ + kind = 'scalar' + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super(ScalarTypeMetadata, self).__init__(*args, **kwargs) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, ScalarTypeMetadata): + return (self.name == other.name and + self.type_str == other.type_str and + self.sort_order == other.sort_order and + self._description == other._description and + self._badges == other._badges and + self.key() == other.key()) + return False + + def is_terminal_type(self) -> bool: + return True + + def create_node_iterator(self) -> Iterator[GraphNode]: + node_attributes: Dict[str, Union[str, None, int]] = { + TypeMetadata.KIND: self.kind, + TypeMetadata.NAME: self.name, + TypeMetadata.DATA_TYPE: self.type_str + } + + if isinstance(self.sort_order, int): + node_attributes[TypeMetadata.SORT_ORDER] = self.sort_order + + yield GraphNode( + key=self.key(), + label=TypeMetadata.NODE_LABEL, + attributes=node_attributes + ) + + if self._description: + description_key = self.description_key() + assert description_key is not None, f"Could not retrieve description key for {self.name}" + yield self._description.get_node(description_key) + + if self._badges: + badge_metadata = BadgeMetadata(start_label=TypeMetadata.NODE_LABEL, + start_key=self.key(), + badges=self._badges) + badge_nodes = badge_metadata.get_badge_nodes() + for node in badge_nodes: + yield node + + def create_relation_iterator(self) -> Iterator[GraphRelationship]: + yield GraphRelationship( + start_label=self.parent_label(), + start_key=self.parent_key(), + end_label=TypeMetadata.NODE_LABEL, + end_key=self.key(), + type=self.relation_type(), + reverse_type=self.inverse_relation_type(), + attributes={} + ) + + if self._description: + description_key = self.description_key() + assert description_key is not None, f"Could not retrieve description key for {self.name}" + yield self._description.get_relation( + TypeMetadata.NODE_LABEL, + self.key(), + description_key + ) + + if self._badges: + badge_metadata = BadgeMetadata(start_label=TypeMetadata.NODE_LABEL, + start_key=self.key(), + badges=self._badges) + 
badge_relations = badge_metadata.get_badge_relations() + for relation in badge_relations: + yield relation + + +class StructTypeMetadata(TypeMetadata): + kind = 'struct' + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super(StructTypeMetadata, self).__init__(*args, **kwargs) + self.struct_items: Optional[Dict[str, TypeMetadata]] = None + + def __eq__(self, other: Any) -> bool: + if isinstance(other, StructTypeMetadata): + return (self.name == other.name and + self.struct_items == other.struct_items and + self.type_str == other.type_str and + self.sort_order == other.sort_order and + self._description == other._description and + self._badges == other._badges and + self.key() == other.key()) + return False + + def is_terminal_type(self) -> bool: + return not self.struct_items + + def create_node_iterator(self) -> Iterator[GraphNode]: + node_attributes: Dict[str, Union[str, None, int]] = { + TypeMetadata.KIND: self.kind, + TypeMetadata.NAME: self.name, + TypeMetadata.DATA_TYPE: self.type_str + } + + if isinstance(self.sort_order, int): + node_attributes[TypeMetadata.SORT_ORDER] = self.sort_order + + yield GraphNode( + key=self.key(), + label=TypeMetadata.NODE_LABEL, + attributes=node_attributes + ) + + if self._description: + description_key = self.description_key() + assert description_key is not None, f"Could not retrieve description key for {self.name}" + yield self._description.get_node(description_key) + + if self._badges: + badge_metadata = BadgeMetadata(start_label=TypeMetadata.NODE_LABEL, + start_key=self.key(), + badges=self._badges) + badge_nodes = badge_metadata.get_badge_nodes() + for node in badge_nodes: + yield node + + if not self.is_terminal_type(): + assert self.struct_items, f"Struct items must be set for {self.name}" + for name, data_type in self.struct_items.items(): + yield from data_type.create_node_iterator() + + def create_relation_iterator(self) -> Iterator[GraphRelationship]: + yield GraphRelationship( + start_label=self.parent_label(), + start_key=self.parent_key(), + end_label=TypeMetadata.NODE_LABEL, + end_key=self.key(), + type=self.relation_type(), + reverse_type=self.inverse_relation_type(), + attributes={} + ) + + if self._description: + description_key = self.description_key() + assert description_key is not None, f"Could not retrieve description key for {self.name}" + yield self._description.get_relation( + TypeMetadata.NODE_LABEL, + self.key(), + description_key + ) + + if self._badges: + badge_metadata = BadgeMetadata(start_label=TypeMetadata.NODE_LABEL, + start_key=self.key(), + badges=self._badges) + badge_relations = badge_metadata.get_badge_relations() + for relation in badge_relations: + yield relation + + if not self.is_terminal_type(): + assert self.struct_items, f"Struct items must be set for {self.name}" + for name, data_type in self.struct_items.items(): + yield from data_type.create_relation_iterator() diff --git a/databuilder/databuilder/models/usage/__init__.py b/databuilder/databuilder/models/usage/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/models/usage/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/models/usage/usage.py b/databuilder/databuilder/models/usage/usage.py new file mode 100644 index 0000000000..5cbe81f11e --- /dev/null +++ b/databuilder/databuilder/models/usage/usage.py @@ -0,0 +1,207 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 +from typing import ( + Iterator, Optional, Union, +) + +from amundsen_common.utils.atlas import ( + AtlasCommonParams, AtlasCommonTypes, AtlasTableKey, +) +from amundsen_rds.models import RDSModel +from amundsen_rds.models.dashboard import DashboardUsage as RDSDashboardUsage +from amundsen_rds.models.table import TableUsage as RDSTableUsage +from amundsen_rds.models.user import User as RDSUser + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.table_metadata import TableMetadata +from databuilder.models.table_serializable import TableSerializable +from databuilder.models.usage.usage_constants import ( + READ_RELATION_COUNT_PROPERTY, READ_RELATION_TYPE, READ_REVERSE_RELATION_TYPE, +) +from databuilder.models.user import User +from databuilder.serializers.atlas_serializer import get_entity_attrs +from databuilder.utils.atlas import AtlasRelationshipTypes, AtlasSerializedEntityOperation + + +class Usage(GraphSerializable, TableSerializable, AtlasSerializable): + LABELS_PERMITTED_TO_HAVE_USAGE = ['Table', 'Dashboard', 'Feature'] + + def __init__(self, + start_label: str, + start_key: str, + user_email: str, + read_count: int = 1) -> None: + + if start_label not in Usage.LABELS_PERMITTED_TO_HAVE_USAGE: + raise Exception(f'usage for {start_label} is not supported') + + self.start_label = start_label + self.start_key = start_key + self.user_email = user_email.strip().lower() + self.read_count = int(read_count) + + self._node_iter = self._create_node_iterator() + self._relation_iter = self._create_relation_iterator() + self._record_iter = self._create_record_iterator() + self._atlas_entity_iterator = self._create_next_atlas_entity() + self._atlas_relation_iterator = self._create_atlas_relation_iterator() + + def __repr__(self) -> str: + return f"Usage(start_label={self.start_label!r}, start_key={self.start_key!r}, " \ + f"user_email={self.user_email!r}, read_count={self.read_count!r})" + + def create_next_node(self) -> Optional[GraphNode]: + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Optional[GraphRelationship]: + try: + return next(self._relation_iter) + except StopIteration: + return None + + def create_next_record(self) -> Optional[RDSModel]: + try: + return next(self._record_iter) + except StopIteration: + return None + + def _create_node_iterator(self) -> Iterator[GraphNode]: + if self.user_email: + yield GraphNode( + key=User.get_user_model_key(email=self.user_email), + label=User.USER_NODE_LABEL, + attributes={ + User.USER_NODE_EMAIL: self.user_email, + } + ) + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + yield GraphRelationship( + start_label=self.start_label, + start_key=self.start_key, + end_label=User.USER_NODE_LABEL, + end_key=User.get_user_model_key(email=self.user_email), + type=READ_REVERSE_RELATION_TYPE, + reverse_type=READ_RELATION_TYPE, + attributes={ + READ_RELATION_COUNT_PROPERTY: self.read_count, + } + ) + + def _create_record_iterator(self) -> Iterator[RDSModel]: + if self.user_email: + yield RDSUser( + 
rk=User.get_user_model_key(email=self.user_email), + email=self.user_email + ) + + if self.start_label == TableMetadata.TABLE_NODE_LABEL: + yield RDSTableUsage(user_rk=User.get_user_model_key(email=self.user_email), + table_rk=self.start_key, + read_count=self.read_count) + elif self.start_label == DashboardMetadata.DASHBOARD_NODE_LABEL: + yield RDSDashboardUsage( + user_rk=User.get_user_model_key(email=self.user_email), + dashboard_rk=self.start_key, + read_count=self.read_count, + ) + else: + raise Exception(f'{self.start_label} usage is not table serializable') + + def _get_user_key(self) -> str: + return User.get_user_model_key(email=self.user_email) + + def _get_reader_key(self) -> str: + return f'{self.start_key}/_reader/{self._get_user_key()}' + + def _get_entity_type(self) -> str: + if self.start_label == 'Table': + entity_type = AtlasTableKey(self.start_key).entity_type + else: + entity_type = self.start_label + + return entity_type + + def _create_atlas_user_entity(self) -> AtlasEntity: + attrs_mapping = [ + (AtlasCommonParams.qualified_name, self._get_user_key()), + ('email', self._get_user_key()) + ] + + entity_attrs = get_entity_attrs(attrs_mapping) + + entity = AtlasEntity( + typeName=AtlasCommonTypes.user, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=entity_attrs, + relationships=None + ) + return entity + + def _create_atlas_reader_entity(self) -> AtlasEntity: + attrs_mapping = [ + (AtlasCommonParams.qualified_name, self._get_reader_key()), + ('count', self.read_count), + ('entityUri', self.start_label) + ] + + entity_attrs = get_entity_attrs(attrs_mapping) + + entity = AtlasEntity( + typeName=AtlasCommonTypes.reader, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=entity_attrs, + relationships=None + ) + return entity + + def _create_atlas_reader_dataset_relation(self) -> AtlasRelationship: + relationship = AtlasRelationship( + relationshipType=AtlasRelationshipTypes.referenceable_reader, + entityType1=self._get_entity_type(), + entityQualifiedName1=self.start_key, + entityType2=AtlasCommonTypes.reader, + entityQualifiedName2=self._get_reader_key(), + attributes=dict(count=self.read_count) + ) + return relationship + + def _create_atlas_user_reader_relation(self) -> AtlasRelationship: + relationship = AtlasRelationship( + relationshipType=AtlasRelationshipTypes.reader_user, + entityType1=AtlasCommonTypes.reader, + entityQualifiedName1=self._get_reader_key(), + entityType2=AtlasCommonTypes.user, + entityQualifiedName2=self._get_user_key(), + attributes={} + ) + return relationship + + def _create_next_atlas_entity(self) -> Iterator[AtlasEntity]: + yield self._create_atlas_user_entity() + yield self._create_atlas_reader_entity() + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + try: + return next(self._atlas_entity_iterator) # type: ignore + except StopIteration: + return None + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + try: + return next(self._atlas_relation_iterator) # type: ignore + except StopIteration: + return None + + def _create_atlas_relation_iterator(self) -> Iterator[AtlasRelationship]: + yield self._create_atlas_reader_dataset_relation() + yield self._create_atlas_user_reader_relation() diff --git a/databuilder/databuilder/models/usage/usage_constants.py b/databuilder/databuilder/models/usage/usage_constants.py new file mode 100644 index 0000000000..f9dd962bad --- /dev/null +++ b/databuilder/databuilder/models/usage/usage_constants.py @@ -0,0 +1,7 @@ +# Copyright Contributors 
to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +READ_RELATION_TYPE = 'READ' +READ_REVERSE_RELATION_TYPE = 'READ_BY' + +READ_RELATION_COUNT_PROPERTY = 'read_count' diff --git a/databuilder/databuilder/models/user.py b/databuilder/databuilder/models/user.py new file mode 100644 index 0000000000..0f9b9bcab9 --- /dev/null +++ b/databuilder/databuilder/models/user.py @@ -0,0 +1,279 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import copy +from typing import ( + Any, Iterator, Optional, Union, +) + +from amundsen_common.utils.atlas import AtlasCommonParams, AtlasCommonTypes +from amundsen_rds.models import RDSModel +from amundsen_rds.models.user import User as RDSUser + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.table_serializable import TableSerializable +from databuilder.serializers.atlas_serializer import get_entity_attrs +from databuilder.utils.atlas import AtlasSerializedEntityOperation + + +class User(GraphSerializable, TableSerializable, AtlasSerializable): + """ + User model. This model doesn't define any relationship. + """ + USER_NODE_LABEL = 'User' + USER_NODE_KEY_FORMAT = '{email}' + USER_NODE_EMAIL = 'email' + USER_NODE_FIRST_NAME = 'first_name' + USER_NODE_LAST_NAME = 'last_name' + USER_NODE_FULL_NAME = 'full_name' + USER_NODE_GITHUB_NAME = 'github_username' + USER_NODE_TEAM = 'team_name' + USER_NODE_EMPLOYEE_TYPE = 'employee_type' + USER_NODE_MANAGER_EMAIL = 'manager_email' + USER_NODE_SLACK_ID = 'slack_id' + USER_NODE_IS_ACTIVE = 'is_active' # bool value needs to be unquoted when publish to neo4j + USER_NODE_PROFILE_URL = 'profile_url' + USER_NODE_UPDATED_AT = 'updated_at' + USER_NODE_ROLE_NAME = 'role_name' + + USER_MANAGER_RELATION_TYPE = 'MANAGE_BY' + MANAGER_USER_RELATION_TYPE = 'MANAGE' + + def __init__(self, + email: str, + first_name: str = '', + last_name: str = '', + full_name: str = '', + github_username: str = '', + team_name: str = '', + employee_type: str = '', + manager_email: str = '', + slack_id: str = '', + is_active: bool = True, + profile_url: str = '', + updated_at: int = 0, + role_name: str = '', + do_not_update_empty_attribute: bool = False, + **kwargs: Any + ) -> None: + """ + This class models user node for Amundsen people. + + :param first_name: + :param last_name: + :param full_name: + :param email: + :param github_username: + :param team_name: + :param employee_type: + :param manager_email: + :param is_active: + :param profile_url: + :param updated_at: everytime we update the node, we will push the timestamp. + then we will have a cron job to update the ex-employee nodes based on + the case if this timestamp hasn't been updated for two weeks. + :param role_name: the role_name of the user (e.g swe) + :param do_not_update_empty_attribute: If False, all empty or not defined params will be overwritten with + empty string. 
+ :param kwargs: Any K/V attributes we want to update the + """ + self.first_name = first_name + self.last_name = last_name + self.full_name = full_name + + self.email = email + self.github_username = github_username + # todo: team will be a separate node once Amundsen People supports team + self.team_name = team_name + self.manager_email = manager_email + self.employee_type = employee_type + # this attr not available in team service, either update team service, update with FE + self.slack_id = slack_id + self.is_active = is_active + self.profile_url = profile_url + self.updated_at = updated_at + self.role_name = role_name + self.do_not_update_empty_attribute = do_not_update_empty_attribute + self.attrs = None + if kwargs: + self.attrs = copy.deepcopy(kwargs) + + self._node_iter = self._create_node_iterator() + self._rel_iter = self._create_relation_iterator() + self._record_iter = self._create_record_iterator() + self._atlas_entity_iterator = self._create_next_atlas_entity() + + def create_next_node(self) -> Optional[GraphNode]: + # return the string representation of the data + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Optional[GraphRelationship]: + """ + :return: + """ + try: + return next(self._rel_iter) + except StopIteration: + return None + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iter) + except StopIteration: + return None + + @classmethod + def get_user_model_key(cls, + email: Optional[str] = None + ) -> str: + if not email: + return '' + return User.USER_NODE_KEY_FORMAT.format(email=email) + + def get_user_node(self) -> GraphNode: + node_attributes = { + User.USER_NODE_EMAIL: self.email, + User.USER_NODE_IS_ACTIVE: self.is_active, + User.USER_NODE_PROFILE_URL: self.profile_url or '', + User.USER_NODE_FIRST_NAME: self.first_name or '', + User.USER_NODE_LAST_NAME: self.last_name or '', + User.USER_NODE_FULL_NAME: self.full_name or '', + User.USER_NODE_GITHUB_NAME: self.github_username or '', + User.USER_NODE_TEAM: self.team_name or '', + User.USER_NODE_EMPLOYEE_TYPE: self.employee_type or '', + User.USER_NODE_SLACK_ID: self.slack_id or '', + User.USER_NODE_ROLE_NAME: self.role_name or '' + } + + if self.updated_at: + node_attributes[User.USER_NODE_UPDATED_AT] = self.updated_at + elif not self.do_not_update_empty_attribute: + node_attributes[User.USER_NODE_UPDATED_AT] = 0 + + if self.attrs: + for k, v in self.attrs.items(): + if k not in node_attributes: + node_attributes[k] = v + + if self.do_not_update_empty_attribute: + for k, v in list(node_attributes.items()): + if not v: + del node_attributes[k] + + node = GraphNode( + key=User.get_user_model_key(email=self.email), + label=User.USER_NODE_LABEL, + attributes=node_attributes + ) + + return node + + def get_user_record(self) -> RDSModel: + record_attr_map = { + RDSUser.email: self.email, + RDSUser.is_active: self.is_active, + RDSUser.profile_url: self.profile_url or '', + RDSUser.first_name: self.first_name or '', + RDSUser.last_name: self.last_name or '', + RDSUser.full_name: self.full_name or '', + RDSUser.github_username: self.github_username or '', + RDSUser.team_name: self.team_name or '', + RDSUser.employee_type: self.employee_type or '', + RDSUser.slack_id: self.slack_id or '', + RDSUser.role_name: self.role_name or '', + RDSUser.updated_at: self.updated_at or 0 + } + + record = RDSUser(rk=User.get_user_model_key(email=self.email)) + # set value for attributes of user record if the value is not empty + # 
or the flag allows to update empty values + for attr, value in record_attr_map.items(): + if value or not self.do_not_update_empty_attribute: + record.__setattr__(attr.key, value) + + if self.manager_email: + record.manager_rk = self.get_user_model_key(email=self.manager_email) + + return record + + def _create_node_iterator(self) -> Iterator[GraphNode]: + """ + Create an user node + :return: + """ + user_node = self.get_user_node() + yield user_node + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + if self.manager_email: + # only create the relation if the manager exists + relationship = GraphRelationship( + start_key=User.get_user_model_key(email=self.email), + start_label=User.USER_NODE_LABEL, + end_label=User.USER_NODE_LABEL, + end_key=self.get_user_model_key(email=self.manager_email), + type=User.USER_MANAGER_RELATION_TYPE, + reverse_type=User.MANAGER_USER_RELATION_TYPE, + attributes={} + ) + yield relationship + + def _create_record_iterator(self) -> Iterator[RDSModel]: + user_record = self.get_user_record() + yield user_record + + def _create_atlas_user_entity(self) -> AtlasEntity: + attrs_mapping = [ + (AtlasCommonParams.qualified_name, User.get_user_model_key(email=self.email)), + ('email', self.email), + ('first_name', self.first_name), + ('last_name', self.last_name), + ('full_name', self.full_name), + ('github_username', self.github_username), + ('team_name', self.team_name), + ('employee_type', self.employee_type), + ('manager_email', self.manager_email), + ('slack_id', self.slack_id), + ('is_active', self.is_active), + ('profile_url', self.profile_url), + ('updated_at', self.updated_at), + ('role_name', self.role_name), + ('displayName', self.email) + ] + + entity_attrs = get_entity_attrs(attrs_mapping) + + entity = AtlasEntity( + typeName=AtlasCommonTypes.user, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=entity_attrs, + relationships=None + ) + + return entity + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + pass + + def _create_next_atlas_entity(self) -> Iterator[AtlasEntity]: + yield self._create_atlas_user_entity() + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + try: + return next(self._atlas_entity_iterator) + except StopIteration: + return None + + def __repr__(self) -> str: + return f'User({self.first_name!r}, {self.last_name!r}, {self.full_name!r}, {self.email!r}, ' \ + f'{self.github_username!r}, {self.team_name!r}, {self.slack_id!r}, {self.manager_email!r}, ' \ + f'{self.employee_type!r}, {self.is_active!r}, {self.profile_url!r}, {self.updated_at!r}, ' \ + f'{self.role_name!r})' diff --git a/databuilder/databuilder/models/user_elasticsearch_document.py b/databuilder/databuilder/models/user_elasticsearch_document.py new file mode 100644 index 0000000000..65b35cea27 --- /dev/null +++ b/databuilder/databuilder/models/user_elasticsearch_document.py @@ -0,0 +1,41 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from databuilder.models.elasticsearch_document import ElasticsearchDocument + + +class UserESDocument(ElasticsearchDocument): + """ + Schema for the Search index document for user + """ + + def __init__(self, + email: str, + first_name: str, + last_name: str, + full_name: str, + github_username: str, + team_name: str, + employee_type: str, + manager_email: str, + slack_id: str, + role_name: str, + is_active: bool, + total_read: int, + total_own: int, + total_follow: int, + ) -> None: + self.email = email + self.first_name = first_name + self.last_name = last_name + self.full_name = full_name + self.github_username = github_username + self.team_name = team_name + self.employee_type = employee_type + self.manager_email = manager_email + self.slack_id = slack_id + self.role_name = role_name + self.is_active = is_active + self.total_read = total_read + self.total_own = total_own + self.total_follow = total_follow diff --git a/databuilder/databuilder/models/watermark.py b/databuilder/databuilder/models/watermark.py new file mode 100644 index 0000000000..baa3375b7d --- /dev/null +++ b/databuilder/databuilder/models/watermark.py @@ -0,0 +1,186 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import ( + Iterator, List, Tuple, Union, +) + +from amundsen_common.utils.atlas import AtlasCommonParams, AtlasTableTypes +from amundsen_rds.models import RDSModel +from amundsen_rds.models.table import TableWatermark as RDSTableWatermark + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.table_serializable import TableSerializable +from databuilder.serializers.atlas_serializer import ( + add_entity_relationship, get_entity_attrs, get_entity_relationships, +) +from databuilder.utils.atlas import AtlasSerializedEntityOperation + + +class Watermark(GraphSerializable, TableSerializable, AtlasSerializable): + """ + Table watermark result model. + Each instance represents one row of table watermark result. 
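+    A watermark captures a partition boundary (by default the high watermark) of a
+    partitioned table; part_name must be of the form '<partition_key>=<value>'.
+
+    Example (illustrative values):
+        Watermark(create_time='2020-01-01 00:00:00', database='hive',
+                  schema='schema', table_name='table', part_name='ds=2020-01-01')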
+ """ + LABEL = 'Watermark' + KEY_FORMAT = '{database}://{cluster}.{schema}' \ + '/{table}/{part_type}/' + WATERMARK_TABLE_RELATION_TYPE = 'BELONG_TO_TABLE' + TABLE_WATERMARK_RELATION_TYPE = 'WATERMARK' + + def __init__(self, + create_time: str, + database: str, + schema: str, + table_name: str, + part_name: str, + part_type: str = 'high_watermark', + cluster: str = 'gold', + ) -> None: + self.create_time = create_time + self.database = database + self.schema = schema + self.table = table_name + self.parts: List[Tuple[str, str]] = [] + + if '=' not in part_name: + raise Exception('Only partition table has high watermark') + + # currently we don't consider nested partitions + idx = part_name.find('=') + name, value = part_name[:idx], part_name[idx + 1:] + self.parts = [(name, value)] + self.part_type = part_type + self.cluster = cluster + self._node_iter = self._create_node_iterator() + self._relation_iter = self._create_relation_iterator() + self._record_iter = self._create_next_record() + self._atlas_entity_iterator = self._create_next_atlas_entity() + + def __repr__(self) -> str: + return f"Watermark(create_time={str(self.create_time)!r}, database={self.database!r}, " \ + f"schema={self.schema!r}, table={self.table!r}, parts={self.parts!r}, " \ + f"cluster={self.cluster!r}, part_type={self.part_type!r})" + + def create_next_node(self) -> Union[GraphNode, None]: + # return the string representation of the data + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Union[GraphRelationship, None]: + try: + return next(self._relation_iter) + except StopIteration: + return None + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iter) + except StopIteration: + return None + + def get_watermark_model_key(self) -> str: + return Watermark.KEY_FORMAT.format(database=self.database, + cluster=self.cluster, + schema=self.schema, + table=self.table, + part_type=self.part_type) + + def get_metadata_model_key(self) -> str: + return f'{self.database}://{self.cluster}.{self.schema}/{self.table}' + + def _create_node_iterator(self) -> Iterator[GraphNode]: + """ + Create watermark nodes + :return: + """ + for part in self.parts: + part_node = GraphNode( + key=self.get_watermark_model_key(), + label=Watermark.LABEL, + attributes={ + 'partition_key': part[0], + 'partition_value': part[1], + 'create_time': self.create_time + } + ) + yield part_node + + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + """ + Create relation map between watermark record with original table + :return: + """ + relation = GraphRelationship( + start_key=self.get_watermark_model_key(), + start_label=Watermark.LABEL, + end_key=self.get_metadata_model_key(), + end_label='Table', + type=Watermark.WATERMARK_TABLE_RELATION_TYPE, + reverse_type=Watermark.TABLE_WATERMARK_RELATION_TYPE, + attributes={} + ) + yield relation + + def _create_next_record(self) -> Iterator[RDSModel]: + """ + Create watermark records + """ + for part in self.parts: + part_record = RDSTableWatermark( + rk=self.get_watermark_model_key(), + partition_key=part[0], + partition_value=part[1], + create_time=self.create_time, + table_rk=self.get_metadata_model_key() + ) + yield part_record + + def _create_atlas_partition_entity(self, spec: Tuple[str, str]) -> AtlasEntity: + attrs_mapping = [ + (AtlasCommonParams.qualified_name, self.get_watermark_model_key()), + ('name', spec[1]), + ('displayName', spec[1]), + ('key', spec[0]), + ('create_time', 
self.create_time) + ] + + entity_attrs = get_entity_attrs(attrs_mapping) + + relationship_list = [] # type: ignore + + add_entity_relationship( + relationship_list, + 'table', + AtlasTableTypes.table, + self.get_metadata_model_key() + ) + + entity = AtlasEntity( + typeName=AtlasTableTypes.watermark, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=entity_attrs, + relationships=get_entity_relationships(relationship_list) + ) + + return entity + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + pass + + def _create_next_atlas_entity(self) -> Iterator[AtlasEntity]: + for part in self.parts: + yield self._create_atlas_partition_entity(part) + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + try: + return next(self._atlas_entity_iterator) + except StopIteration: + return None diff --git a/databuilder/databuilder/publisher/__init__.py b/databuilder/databuilder/publisher/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/databuilder/publisher/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/databuilder/publisher/atlas_csv_publisher.py b/databuilder/databuilder/publisher/atlas_csv_publisher.py new file mode 100644 index 0000000000..894df5769f --- /dev/null +++ b/databuilder/databuilder/publisher/atlas_csv_publisher.py @@ -0,0 +1,363 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 +import logging +from os import listdir +from os.path import isfile, join +from typing import ( + Any, Dict, Iterator, List, Tuple, +) + +import pandas +from amundsen_common.utils.atlas import AtlasCommonParams, AtlasCommonTypes +from apache_atlas.exceptions import AtlasServiceException +from apache_atlas.model.glossary import ( + AtlasGlossary, AtlasGlossaryHeader, AtlasGlossaryTerm, +) +from apache_atlas.model.instance import ( + AtlasEntitiesWithExtInfo, AtlasEntity, AtlasObjectId, AtlasRelatedObjectId, +) +from apache_atlas.model.relationship import AtlasRelationship +from apache_atlas.model.typedef import AtlasClassificationDef, AtlasTypesDef +from pyhocon import ConfigTree + +from databuilder.publisher.base_publisher import Publisher +from databuilder.types.atlas import AtlasEntityInitializer +from databuilder.utils.atlas import ( + AtlasRelationshipTypes, AtlasSerializedEntityFields, AtlasSerializedEntityOperation, + AtlasSerializedRelationshipFields, +) + +LOGGER = logging.getLogger(__name__) + + +class AtlasCSVPublisher(Publisher): + # atlas client + ATLAS_CLIENT = 'atlas_client' + # A directory that contains CSV files for entities + ENTITY_DIR_PATH = 'entity_files_directory' + # A directory that contains CSV files for relationships + RELATIONSHIP_DIR_PATH = 'relationship_files_directory' + # atlas create entity batch size + ATLAS_ENTITY_CREATE_BATCH_SIZE = 'batch_size' + # whether entity types should be registered before data is synced to Atlas + REGISTER_ENTITY_TYPES = 'register_entity_types' + + def __init__(self) -> None: + super().__init__() + + def init(self, conf: ConfigTree) -> None: + self._entity_files = self._list_files(conf, AtlasCSVPublisher.ENTITY_DIR_PATH) + self._relationship_files = self._list_files(conf, AtlasCSVPublisher.RELATIONSHIP_DIR_PATH) + self._config = conf + self._atlas_client = self._config.get(AtlasCSVPublisher.ATLAS_CLIENT) + self._register_entity_types = self._config.get_bool(AtlasCSVPublisher.REGISTER_ENTITY_TYPES, True) + + if 
self._register_entity_types: + LOGGER.info('Registering Atlas Entity Types.') + + try: + init = AtlasEntityInitializer(self._atlas_client) + init.create_required_entities() + + LOGGER.info('Registered Atlas Entity Types.') + except Exception: + LOGGER.error('Failed to register Atlas Entity Types.', exc_info=True) + + def _list_files(self, conf: ConfigTree, path_key: str) -> List[str]: + """ + List files from directory + :param conf: + :param path_key: + :return: List of file paths + """ + if path_key not in conf: + return [] + + path = conf.get_string(path_key) + return sorted(join(path, f) for f in listdir(path) if isfile(join(path, f))) + + def publish_impl(self) -> None: + """ + Publishes Entities first and then Relations + :return: + """ + LOGGER.info('Creating entities using Entity files: %s', self._entity_files) + for entity_file in self._entity_files: + entities_to_create, entities_to_update, \ + glossary_terms_create, classifications_create = self._create_entity_instances(entity_file=entity_file) + self._sync_entities_to_atlas(entities_to_create) + self._update_entities(entities_to_update) + self._create_glossary_terms(glossary_terms_create) + self._create_classifications(classifications_create) + + LOGGER.info('Creating relations using relation files: %s', self._relationship_files) + for relation_file in self._relationship_files: + self._create_relations(relation_file=relation_file) + + def _update_entities(self, entities_to_update: List[AtlasEntity]) -> None: + """ + Go over the entities list , create atlas relationships instances and sync them with atlas + :param entities_to_update: + :return: + """ + for entity_to_update in entities_to_update: + existing_entity = self._atlas_client.entity.get_entity_by_attribute( + entity_to_update.attributes[AtlasCommonParams.type_name], + [(AtlasCommonParams.qualified_name, entity_to_update.attributes[AtlasCommonParams.qualified_name])], + ) + existing_entity.entity.attributes.update(entity_to_update.attributes) + try: + self._atlas_client.entity.update_entity(existing_entity) + except AtlasServiceException: + LOGGER.error('Fail to update entity', exc_info=True) + + def _create_relations(self, relation_file: str) -> None: + """ + Go over the relation file, create atlas relationships instances and sync them with atlas + :param relation_file: + :return: + """ + + with open(relation_file, encoding='utf8') as relation_csv: + for relation_record in pandas.read_csv(relation_csv, na_filter=False).to_dict(orient='records'): + if relation_record[AtlasSerializedRelationshipFields.relation_type] == AtlasRelationshipTypes.tag: + self._assign_glossary_term(relation_record) + continue + elif relation_record[AtlasSerializedRelationshipFields.relation_type] == AtlasRelationshipTypes.badge: + self._assign_classification(relation_record) + continue + + relation = self._create_relation(relation_record) + try: + self._atlas_client.relationship.create_relationship(relation) + except AtlasServiceException: + LOGGER.error('Fail to create atlas relationship', exc_info=True) + except Exception as e: + LOGGER.error(e) + + def _render_unique_attributes(self, entity_type: str, qualified_name: str) -> Dict[Any, Any]: + """ + Render uniqueAttributes dict, this struct is needed to identify AtlasObjects + :param entity_type: + :param qualified_name: + :return: rendered uniqueAttributes dict + """ + return { + AtlasCommonParams.type_name: entity_type, + AtlasCommonParams.unique_attributes: { + AtlasCommonParams.qualified_name: qualified_name, + }, + } + + def 
_get_atlas_related_object_id_by_qn(self, entity_type: str, qn: str) -> AtlasRelatedObjectId: + return AtlasRelatedObjectId(attrs=self._render_unique_attributes(entity_type, qn)) + + def _get_atlas_object_id_by_qn(self, entity_type: str, qn: str) -> AtlasObjectId: + return AtlasObjectId(attrs=self._render_unique_attributes(entity_type, qn)) + + def _create_relation(self, relation_dict: Dict[str, str]) -> AtlasRelationship: + """ + Go over the relation dictionary file and create atlas relationships instances + :param relation_dict: + :return: + """ + + relation = AtlasRelationship( + {AtlasCommonParams.type_name: relation_dict[AtlasSerializedRelationshipFields.relation_type]}, + ) + relation.end1 = self._get_atlas_object_id_by_qn( + relation_dict[AtlasSerializedRelationshipFields.entity_type_1], + relation_dict[AtlasSerializedRelationshipFields.qualified_name_1], + ) + relation.end2 = self._get_atlas_object_id_by_qn( + relation_dict[AtlasSerializedRelationshipFields.entity_type_2], + relation_dict[AtlasSerializedRelationshipFields.qualified_name_2], + ) + + return relation + + def _create_entity_instances(self, entity_file: str) -> Tuple[List[AtlasEntity], List[AtlasEntity], + List[Dict], List[Dict]]: + """ + Go over the entities file and try creating instances + :param entity_file: + :return: + """ + entities_to_create = [] + entities_to_update = [] + glossary_terms_to_create = [] + classifications_to_create = [] + with open(entity_file, encoding='utf8') as entity_csv: + for entity_record in pandas.read_csv(entity_csv, na_filter=False).to_dict(orient='records'): + if entity_record[AtlasSerializedEntityFields.type_name] == AtlasCommonTypes.tag: + glossary_terms_to_create.append(entity_record) + continue + + if entity_record[AtlasSerializedEntityFields.type_name] == AtlasCommonTypes.badge: + classifications_to_create.append(entity_record) + continue + + if entity_record[AtlasSerializedEntityFields.operation] == AtlasSerializedEntityOperation.CREATE: + entities_to_create.append(self._create_entity_from_dict(entity_record)) + if entity_record[AtlasSerializedEntityFields.operation] == AtlasSerializedEntityOperation.UPDATE: + entities_to_update.append(self._create_entity_from_dict(entity_record)) + return entities_to_create, entities_to_update, glossary_terms_to_create, classifications_to_create + + def _extract_entity_relations_details(self, relation_details: str) -> Iterator[Tuple]: + """ + Generate relation details from relation_attr#related_entity_type#related_qualified_name + """ + relations = relation_details.split(AtlasSerializedEntityFields.relationships_separator) + for relation in relations: + relation_split = relation.split(AtlasSerializedEntityFields.relationships_kv_separator) + yield relation_split[0], relation_split[1], relation_split[2] + + def _create_entity_from_dict(self, entity_dict: Dict) -> AtlasEntity: + """ + Create atlas entity instance from dict + :param entity_dict: + :return: AtlasEntity + """ + type_name = {AtlasCommonParams.type_name: entity_dict[AtlasCommonParams.type_name]} + entity = AtlasEntity(type_name) + entity.attributes = entity_dict + relationships = entity_dict.get(AtlasSerializedEntityFields.relationships) + if relationships: + relations = {} + for relation_attr, rel_type, rel_qn in self._extract_entity_relations_details(relationships): + related_obj = self._get_atlas_related_object_id_by_qn(rel_type, rel_qn) + relations[relation_attr] = related_obj + entity.relationshipAttributes = relations + return entity + + def _chunks(self, lst: List) -> Iterator: + 
""" + Yield successive n-sized chunks from lst. + :param lst: + :return: chunks generator + """ + n = self._config.get_int(AtlasCSVPublisher.ATLAS_ENTITY_CREATE_BATCH_SIZE) + for i in range(0, len(lst), n): + yield lst[i:i + n] + + def _sync_entities_to_atlas(self, entities: List[AtlasEntity]) -> None: + """ + Sync entities instances with atlas + :param entities: list of entities + :return: + """ + entities_chunks = self._chunks(entities) + for entity_chunk in entities_chunks: + LOGGER.info(f'Syncing chunk of {len(entity_chunk)} entities with atlas') + chunk = AtlasEntitiesWithExtInfo() + chunk.entities = entity_chunk + try: + self._atlas_client.entity.create_entities(chunk) + except AtlasServiceException: + LOGGER.error('Error during entity syncing', exc_info=True) + + def _create_glossary_terms(self, glossary_terms: List[Dict]) -> None: + for glossary_term_spec in glossary_terms: + glossary_name = glossary_term_spec.get('glossary') + term_name = glossary_term_spec.get('term') + + glossary_def = AtlasGlossary({'name': glossary_name, 'shortDescription': ''}) + + try: + glossary = self._atlas_client.glossary.create_glossary(glossary_def) + except AtlasServiceException: + LOGGER.info(f'Glossary: {glossary_name} already exists.') + glossary = next(filter(lambda x: x.get('name') == glossary_name, + self._atlas_client.glossary.get_all_glossaries())) + + glossary_guid = glossary['guid'] + glossary_def = AtlasGlossaryHeader({'glossaryGuid': glossary_guid}) + term_def = AtlasGlossaryTerm({'name': term_name, 'anchor': glossary_def}) + + try: + self._atlas_client.glossary.create_glossary_term(term_def) + except AtlasServiceException: + LOGGER.info(f'Glossary Term: {term_name} already exists.') + + def _assign_glossary_term(self, relationship_spec: Dict) -> None: + _glossary_name, _term_name = relationship_spec[AtlasSerializedRelationshipFields.qualified_name_2].split(',') + + glossary_name = _glossary_name.split('=')[1] + term_name = _term_name.split('=')[1] + + entity_type = relationship_spec[AtlasSerializedRelationshipFields.entity_type_1] + entity_qn = relationship_spec[AtlasSerializedRelationshipFields.qualified_name_1] + + glossary = next(filter(lambda g: g.get('name') == glossary_name, + self._atlas_client.glossary.get_all_glossaries())) + + glossary_guid = glossary[AtlasCommonParams.guid] + + term = next(filter(lambda t: t.get('name') == term_name, + self._atlas_client.glossary.get_glossary_terms(glossary_guid))) + + entity = self._atlas_client.entity.get_entity_by_attribute(entity_type, uniq_attributes=[ + (AtlasCommonParams.qualified_name, entity_qn)]) + + entity_guid = entity.entity.guid + + e = AtlasRelatedObjectId({AtlasCommonParams.guid: entity_guid, AtlasCommonParams.type_name: entity_type}) + + try: + self._atlas_client.glossary.assign_term_to_entities(term[AtlasCommonParams.guid], [e]) + except Exception: + LOGGER.error('Error assigning terms to entities.', exc_info=True) + + def _render_super_type_from_dict(self, classification_spec: Dict) -> AtlasClassificationDef: + return self._render_classification(classification_spec, True) + + def _render_sub_type_from_dict(self, classification_spec: Dict) -> AtlasClassificationDef: + return self._render_classification(classification_spec, False) + + def _render_classification(self, classification_spec: Dict, super_type: bool) -> AtlasClassificationDef: + name = classification_spec.get('category') if super_type else classification_spec.get('name') + sub_types = [classification_spec.get('category')] if not super_type else [] + + result = 
AtlasClassificationDef(attrs=dict(name=name, + attributeDefs=[], + subTypes=sub_types, + superTypes=[], + entityTypes=[])) + + return result + + def _create_classifications(self, classifications: List[Dict]) -> None: + _st = set() + super_types = [self._render_super_type_from_dict(s) for s in classifications if s['category'] not in _st and + _st.add(s['category'])] # type: ignore + super_types_chunks = self._chunks(super_types) + + sub_types = [self._render_sub_type_from_dict(s) for s in classifications] + sub_types_chunks = self._chunks(sub_types) + + for chunks in [super_types_chunks, sub_types_chunks]: + for chunk in chunks: + LOGGER.info(f'Syncing chunk of {len(chunk)} classifications with atlas') + try: + types = AtlasTypesDef(attrs=dict(classificationDefs=chunk)) + + self._atlas_client.typedef.create_atlas_typedefs(types) + except AtlasServiceException: + LOGGER.error('Error during classification syncing', exc_info=True) + + def _assign_classification(self, relationship_spec: Dict) -> None: + classification_qn = relationship_spec[AtlasSerializedRelationshipFields.qualified_name_2] + + entity_type = relationship_spec[AtlasSerializedRelationshipFields.entity_type_1] + entity_qn = relationship_spec[AtlasSerializedRelationshipFields.qualified_name_1] + + try: + self._atlas_client.entity.add_classifications_by_type(entity_type, + uniq_attributes=[(AtlasCommonParams.qualified_name, + entity_qn)], + classifications=[classification_qn]) + except Exception: + LOGGER.error('Error during classification assingment.', exc_info=True) + + def get_scope(self) -> str: + return 'publisher.atlas_csv_publisher' diff --git a/databuilder/databuilder/publisher/base_publisher.py b/databuilder/databuilder/publisher/base_publisher.py new file mode 100644 index 0000000000..13c84a080d --- /dev/null +++ b/databuilder/databuilder/publisher/base_publisher.py @@ -0,0 +1,76 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import abc +from typing import List + +from pyhocon import ConfigTree + +from databuilder import Scoped +from databuilder.callback import call_back +from databuilder.callback.call_back import Callback + + +class Publisher(Scoped): + """ + A Publisher that writes dataset (not a record) in Atomic manner, + if possible. + (Either success or fail, no partial state) + Use case: If you want to use neo4j import util or Load CSV util, + that takes CSV file to load database, you need to first create CSV file. + CSV file holds number of records, and loader can writes multiple records + to it. Once loader finishes writing CSV file, you have complete CSV file, + ready to publish to Neo4j. Publisher can take the location of CSV file, + and push to Neo4j. + + """ + + def __init__(self) -> None: + self.call_backs: List[Callback] = [] + + @abc.abstractmethod + def init(self, conf: ConfigTree) -> None: + pass + + def publish(self) -> None: + try: + self.publish_impl() + except Exception as e: + call_back.notify_callbacks(self.call_backs, is_success=False) + raise e + call_back.notify_callbacks(self.call_backs, is_success=True) + + @abc.abstractmethod + def publish_impl(self) -> None: + """ + An implementation of publish method. 
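For illustration only, a complete subclass built on this contract might look like the sketch below. The FileCountPublisher name, its 'file_path' config key, and the sample path are made up for this example and are not part of databuilder; it simply shows init/publish_impl/get_scope being overridden and publish() driving the callbacks.

from pyhocon import ConfigFactory, ConfigTree

from databuilder.publisher.base_publisher import Publisher


class FileCountPublisher(Publisher):
    """Hypothetical publisher: 'publishes' a prepared file by reporting its record count."""

    def init(self, conf: ConfigTree) -> None:
        # Receives the publisher-scoped configuration from the job.
        self.file_path = conf.get_string('file_path')

    def publish_impl(self) -> None:
        # The whole dataset is handled in one shot; callbacks registered via
        # register_call_back() are notified by publish() once this returns or raises.
        with open(self.file_path, encoding='utf8') as f:
            print(f'would publish {sum(1 for _ in f)} records')

    def get_scope(self) -> str:
        return 'publisher.file_count'


publisher = FileCountPublisher()
publisher.init(ConfigFactory.from_dict({'file_path': '/tmp/prepared_records.csv'}))
publisher.publish()
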
Subclass of publisher is expected to write publish logic by overriding + this method + :return: None + """ + pass + + def register_call_back(self, callback: Callback) -> None: + """ + Register any callback method that needs to be notified when publisher is either able to successfully publish + or failed to publish + :param callback: + :return: None + """ + self.call_backs.append(callback) + + def get_scope(self) -> str: + return 'publisher' + + +class NoopPublisher(Publisher): + def __init__(self) -> None: + super(NoopPublisher, self).__init__() + + def init(self, conf: ConfigTree) -> None: + pass + + def publish_impl(self) -> None: + pass + + def get_scope(self) -> str: + return 'publisher.noop' diff --git a/databuilder/databuilder/publisher/elasticsearch_constants.py b/databuilder/databuilder/publisher/elasticsearch_constants.py new file mode 100644 index 0000000000..318fa55db7 --- /dev/null +++ b/databuilder/databuilder/publisher/elasticsearch_constants.py @@ -0,0 +1,12 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from amundsen_common.models.index_map import ( + DASHBOARD_ELASTICSEARCH_INDEX_MAPPING as DASHBOARD_INDEX_MAP, TABLE_INDEX_MAP, USER_INDEX_MAP, +) + +# Please use constants in amundsen_common.models.index_map directly. This file is only here +# for backwards compatibility. +TABLE_ELASTICSEARCH_INDEX_MAPPING = TABLE_INDEX_MAP +DASHBOARD_ELASTICSEARCH_INDEX_MAPPING = DASHBOARD_INDEX_MAP +USER_ELASTICSEARCH_INDEX_MAPPING = USER_INDEX_MAP diff --git a/databuilder/databuilder/publisher/elasticsearch_publisher.py b/databuilder/databuilder/publisher/elasticsearch_publisher.py new file mode 100644 index 0000000000..671a1d6f28 --- /dev/null +++ b/databuilder/databuilder/publisher/elasticsearch_publisher.py @@ -0,0 +1,147 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import json +import logging +from typing import List + +from amundsen_common.models.index_map import TABLE_INDEX_MAP +from elasticsearch.exceptions import NotFoundError +from pyhocon import ConfigTree + +from databuilder.publisher.base_publisher import Publisher + +LOGGER = logging.getLogger(__name__) + + +################################################################################################## +# +# ElasticsearchPublisher is being deprecated in favor of using SearchMetadatatoElasticasearchTask +# which publishes ES metadata with mappings compatible with amundsensearch >= 4.0.0 +# +################################################################################################## + +class ElasticsearchPublisher(Publisher): + """ + Elasticsearch Publisher uses Bulk API to load data from JSON file. + A new index is created and data is uploaded into it. After the upload + is complete, index alias is swapped to point to new index from old index + and traffic is routed to new index. 
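The alias swap is issued as a single alias-update request, roughly of the shape sketched below (the index and alias names are placeholders; the real payload is assembled in publish_impl further down):

actions = {
    "actions": [
        # point the alias at the freshly built index ...
        {"add": {"index": "table_search_index_20240101", "alias": "table_search_index"}},
        # ... and drop the index the alias pointed to before, in the same request
        {"remove_index": {"index": "table_search_index_20231231"}},
    ]
}
# handed to elasticsearch_client.indices.update_aliases(...) so both steps apply atomically
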
+ + Old index is deleted after the alias swap is complete + """ + FILE_PATH_CONFIG_KEY = 'file_path' + FILE_MODE_CONFIG_KEY = 'mode' + + ELASTICSEARCH_CLIENT_CONFIG_KEY = 'client' + ELASTICSEARCH_DOC_TYPE_CONFIG_KEY = 'doc_type' + ELASTICSEARCH_NEW_INDEX_CONFIG_KEY = 'new_index' + ELASTICSEARCH_ALIAS_CONFIG_KEY = 'alias' + ELASTICSEARCH_MAPPING_CONFIG_KEY = 'mapping' + + # config to control how many max documents to publish at a time + ELASTICSEARCH_PUBLISHER_BATCH_SIZE = 'batch_size' + + DEFAULT_ELASTICSEARCH_INDEX_MAPPING = TABLE_INDEX_MAP + + def __init__(self) -> None: + super(ElasticsearchPublisher, self).__init__() + + def init(self, conf: ConfigTree) -> None: + self.conf = conf + + self.file_path = self.conf.get_string(ElasticsearchPublisher.FILE_PATH_CONFIG_KEY) + self.file_mode = self.conf.get_string(ElasticsearchPublisher.FILE_MODE_CONFIG_KEY, 'r') + + self.elasticsearch_type = self.conf.get_string(ElasticsearchPublisher.ELASTICSEARCH_DOC_TYPE_CONFIG_KEY) + self.elasticsearch_client = self.conf.get(ElasticsearchPublisher.ELASTICSEARCH_CLIENT_CONFIG_KEY) + self.elasticsearch_new_index = self.conf.get(ElasticsearchPublisher.ELASTICSEARCH_NEW_INDEX_CONFIG_KEY) + self.elasticsearch_alias = self.conf.get(ElasticsearchPublisher.ELASTICSEARCH_ALIAS_CONFIG_KEY) + + self.elasticsearch_mapping = self.conf.get(ElasticsearchPublisher.ELASTICSEARCH_MAPPING_CONFIG_KEY, + ElasticsearchPublisher.DEFAULT_ELASTICSEARCH_INDEX_MAPPING) + self.elasticsearch_batch_size = self.conf.get(ElasticsearchPublisher.ELASTICSEARCH_PUBLISHER_BATCH_SIZE, + 10000) + self.file_handler = open(self.file_path, self.file_mode) + + def _fetch_old_index(self) -> List[str]: + """ + Retrieve all indices that currently have {elasticsearch_alias} alias + :return: list of elasticsearch indices + """ + try: + indices = self.elasticsearch_client.indices.get_alias(self.elasticsearch_alias).keys() + return indices + except NotFoundError: + LOGGER.warn("Received index not found error from Elasticsearch. " + + "The index doesn't exist for a newly created ES. It's OK on first run.") + # return empty list on exception + return [] + + def publish_impl(self) -> None: + """ + Use Elasticsearch Bulk API to load data from file to a {new_index}. 
+ After upload, swap alias from {old_index} to {new_index} in a atomic operation + to route traffic to {new_index} + """ + + LOGGER.warn('ElasticsearchPublisher is being deprecated in favor of using SearchMetadatatoElasticasearchTask\ + which publishes ES metadata with mappings compatible with amundsensearch >= 4.0.0') + + actions = [json.loads(line) for line in self.file_handler.readlines()] + # ensure new data exists + if not actions: + LOGGER.warning("received no data to upload to Elasticsearch!") + return + + # Convert object to json for elasticsearch bulk upload + # Bulk load JSON format is defined here: + # https://www.elastic.co/guide/en/elasticsearch/reference/6.2/docs-bulk.html + bulk_actions = [] + cnt = 0 + + # create new index with mapping + self.elasticsearch_client.indices.create(index=self.elasticsearch_new_index, body=self.elasticsearch_mapping) + + for action in actions: + index_row = dict(index=dict(_index=self.elasticsearch_new_index)) + action['resource_type'] = self.elasticsearch_type + + bulk_actions.append(index_row) + bulk_actions.append(action) + cnt += 1 + if cnt == self.elasticsearch_batch_size: + self.elasticsearch_client.bulk(bulk_actions) + LOGGER.info('Publish %i of records to ES', cnt) + cnt = 0 + bulk_actions = [] + + # Do the final bulk actions + if bulk_actions: + self.elasticsearch_client.bulk(bulk_actions) + + # fetch indices that have {elasticsearch_alias} as alias + elasticsearch_old_indices = self._fetch_old_index() + + # update alias to point to the new index + actions = [{"add": {"index": self.elasticsearch_new_index, "alias": self.elasticsearch_alias}}] + + # delete old indices + delete_actions = [{"remove_index": {"index": index}} for index in elasticsearch_old_indices] + actions.extend(delete_actions) + + update_action = {"actions": actions} + + # perform alias update and index delete in single atomic operation + self.elasticsearch_client.indices.update_aliases(update_action) + + def close(self) -> None: + """ + close the file handler + :return: + """ + if self.file_handler: + self.file_handler.close() + + def get_scope(self) -> str: + return 'publisher.elasticsearch' diff --git a/databuilder/databuilder/publisher/mysql_csv_publisher.py b/databuilder/databuilder/publisher/mysql_csv_publisher.py new file mode 100644 index 0000000000..53fdeb19a3 --- /dev/null +++ b/databuilder/databuilder/publisher/mysql_csv_publisher.py @@ -0,0 +1,214 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import time +from os import listdir +from os.path import ( + basename, isfile, join, splitext, +) +from typing import ( + Dict, List, Optional, Type, +) + +import pandas +from amundsen_rds.models import RDSModel +from amundsen_rds.models.base import Base +from pyhocon import ConfigFactory, ConfigTree +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker + +from databuilder.publisher.base_publisher import Publisher + +LOGGER = logging.getLogger(__name__) + + +class MySQLCSVPublisher(Publisher): + """ + A Publisher takes the table record folder as input and publishes csv to MySQL. + The folder contains CSV file(s) for table records. + + The publish job works with rds models and SQLAlchemy ORM for data ingestion into MySQL. 
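As an orientation aid, a configuration for this publisher might be assembled like the sketch below. The paths, credentials and tag value are placeholders, and the keys are shown flat, i.e. as they are read inside init(); in a full job config they sit under the 'publisher.mysql' scope.

from pyhocon import ConfigFactory

mysql_publisher_conf = ConfigFactory.from_dict({
    'record_files_directory': '/tmp/amundsen/records',    # directory holding the record CSV files
    'conn_string': 'mysql+pymysql://user:password@localhost:3306/amundsen',  # placeholder DSN
    'job_publish_tag': 'unique_tag_2024_01_01',
    'transaction_size': 500,     # commit every 500 merged records (the default)
    'engine_echo': False,        # set True to have SQLAlchemy log every statement
})
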
+ For more information: + rds models: https://github.com/amundsen-io/amundsenrds + SQLAlchemy ORM: https://docs.sqlalchemy.org/en/13/orm/ + """ + # Config keys + # A directory that contains CSV files for records + RECORD_FILES_DIR = 'record_files_directory' + # It is used to provide unique tag to each record + JOB_PUBLISH_TAG = 'job_publish_tag' + + # Connection string + CONN_STRING = 'conn_string' + # If its value is true, SQLAlchemy engine will log all statements + ENGINE_ECHO = 'engine_echo' + # Additional arguments used for engine + CONNECT_ARGS = 'connect_args' + # A transaction size that determines how often it commits. + TRANSACTION_SIZE = 'transaction_size' + # A progress report frequency that determines how often it report the progress. + PROGRESS_REPORT_FREQUENCY = 'progress_report_frequency' + + _DEFAULT_CONFIG = ConfigFactory.from_dict({TRANSACTION_SIZE: 500, + PROGRESS_REPORT_FREQUENCY: 500, + ENGINE_ECHO: False}) + + def __init__(self) -> None: + super(MySQLCSVPublisher, self).__init__() + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(MySQLCSVPublisher._DEFAULT_CONFIG) + + self._count: int = 0 + self._progress_report_frequency = conf.get_int(MySQLCSVPublisher.PROGRESS_REPORT_FREQUENCY) + self._record_files = self._list_files(conf, MySQLCSVPublisher.RECORD_FILES_DIR) + self._sorted_record_files = self._sort_record_files(self._record_files) + self._record_files_iter = iter(self._sorted_record_files) + + connect_args = {k: v for k, v in conf.get_config(MySQLCSVPublisher.CONNECT_ARGS, + default=ConfigTree()).items()} + self._engine = create_engine(conf.get_string(MySQLCSVPublisher.CONN_STRING), + echo=conf.get_bool(MySQLCSVPublisher.ENGINE_ECHO), + connect_args=connect_args) + self._session_factory = sessionmaker(bind=self._engine) + self._transaction_size = conf.get_int(MySQLCSVPublisher.TRANSACTION_SIZE) + + self._publish_tag: str = conf.get_string(MySQLCSVPublisher.JOB_PUBLISH_TAG) + if not self._publish_tag: + raise Exception(f'{MySQLCSVPublisher.JOB_PUBLISH_TAG} should not be empty') + + def _list_files(self, conf: ConfigTree, path_key: str) -> List[str]: + """ + List files from directory + :param conf: + :param path_key: + :return: List of file paths + """ + if path_key not in conf: + return [] + + path = conf.get_string(path_key) + return [join(path, f) for f in listdir(path) if isfile(join(path, f))] + + def _sort_record_files(self, files: List[str]) -> List[str]: + """ + Sort record files in the order of table dependencies + :param files: + :return: + """ + sorted_table_names = [table.name for table in Base.metadata.sorted_tables] + return sorted(files, key=lambda file: sorted_table_names.index(self._get_table_name_from_file(file))) + + def _get_table_name_from_file(self, file: str) -> str: + """ + Get table name from file path + :param file: + :return: + """ + try: + filename = splitext(basename(file))[0] + table_name, _ = filename.rsplit('_', 1) + return table_name + except Exception as e: + LOGGER.exception(f'Error encountered while getting table name from file: {file}') + raise e + + def publish_impl(self) -> None: + """ + Publish records + :return: + """ + start = time.time() + + LOGGER.info(f'Publishing record files: {self._sorted_record_files}') + session = self._session_factory() + try: + while True: + try: + record_file = next(self._record_files_iter) + self._publish(record_file=record_file, session=session) + except StopIteration: + break + + LOGGER.info(f'Committed total {self._count} statements') + LOGGER.info(f'Successfully 
published. Elapsed: {time.time() - start} seconds') + except Exception as e: + LOGGER.exception('Failed to publish. Rolling back.') + session.rollback() + raise e + finally: + session.close() + + def _publish(self, record_file: str, session: Session) -> None: + """ + Iterate over each row of the given csv file and convert each record to a rds model instance. + Then the model instance will be inserted/updated in mysql. + :param record_file: + :param session: + :return: + """ + with open(record_file, 'r', encoding='utf8') as record_csv: + table_name = self._get_table_name_from_file(record_file) + table_model = self._get_model_from_table_name(table_name) + if not table_model: + raise RuntimeError(f'Failed to get model for table: {table_name}') + + for record_dict in pandas.read_csv(record_csv, na_filter=False).to_dict(orient='records'): + record = self._create_record(model=table_model, record_dict=record_dict) + session.merge(record) + self._execute(session) + session.commit() + + def _get_model_from_table_name(self, table_name: str) -> Optional[Type[RDSModel]]: + """ + Get rds model for the given table name + :param table_name: + :return: + """ + if hasattr(Base, '_decl_class_registry'): + models_generator = Base._decl_class_registry.values() + elif hasattr(Base, 'registry'): + models_generator = Base.registry._class_registry.values() + else: + raise Exception(f'Failed to get model for table: {table_name}') + + for model in models_generator: + if hasattr(model, '__tablename__') and model.__tablename__ == table_name: + return model + return None + + def _create_record(self, model: Type[RDSModel], record_dict: Dict) -> RDSModel: + """ + Convert the record dict to an instance of the given rds model + and set additional attributes for the record instance + :param model: + :param record_dict: + :return: + """ + record = model(**record_dict) + record.published_tag = self._publish_tag + record.publisher_last_updated_epoch_ms = int(time.time() * 1000) + return record + + def _execute(self, session: Session) -> None: + """ + Commit pending record changes + :param session: + :return: + """ + try: + self._count += 1 + if self._count > 1 and self._count % self._transaction_size == 0: + session.commit() + LOGGER.info(f'Committed {self._count} records so far') + + if self._count > 1 and self._count % self._progress_report_frequency == 0: + LOGGER.info(f'Processed {self._count} records so far') + + except Exception as e: + LOGGER.exception('Failed to commit changes') + raise e + + def get_scope(self) -> str: + return 'publisher.mysql' diff --git a/databuilder/databuilder/publisher/neo4j_csv_publisher.py b/databuilder/databuilder/publisher/neo4j_csv_publisher.py new file mode 100644 index 0000000000..e43c4e610d --- /dev/null +++ b/databuilder/databuilder/publisher/neo4j_csv_publisher.py @@ -0,0 +1,523 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import csv +import ctypes +import logging +import time +from io import open +from os import listdir +from os.path import isfile, join +from typing import ( + Dict, List, Optional, Set, +) + +import neo4j +import pandas +from jinja2 import Template +from neo4j import GraphDatabase, Transaction +from neo4j.api import ( + SECURITY_TYPE_SECURE, SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, parse_neo4j_uri, +) +from neo4j.exceptions import Neo4jError, TransientError +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.publisher.base_publisher import Publisher +from databuilder.publisher.neo4j_preprocessor import NoopRelationPreprocessor +from databuilder.publisher.publisher_config_constants import ( + Neo4jCsvPublisherConfigs, PublishBehaviorConfigs, PublisherConfigs, +) + +# Setting field_size_limit to solve the error below +# _csv.Error: field larger than field limit (131072) +# https://stackoverflow.com/a/54517228/5972935 +csv.field_size_limit(int(ctypes.c_ulong(-1).value // 2)) + +# Config keys +# A directory that contains CSV files for nodes +NODE_FILES_DIR = PublisherConfigs.NODE_FILES_DIR +# A directory that contains CSV files for relationships +RELATION_FILES_DIR = PublisherConfigs.RELATION_FILES_DIR +# A end point for Neo4j e.g: bolt://localhost:9999 +NEO4J_END_POINT_KEY = Neo4jCsvPublisherConfigs.NEO4J_END_POINT_KEY +# A transaction size that determines how often it commits. +NEO4J_TRANSACTION_SIZE = Neo4jCsvPublisherConfigs.NEO4J_TRANSACTION_SIZE +# A progress report frequency that determines how often it report the progress. +NEO4J_PROGRESS_REPORT_FREQUENCY = 'neo4j_progress_report_frequency' +# A boolean flag to make it fail if relationship is not created +NEO4J_RELATIONSHIP_CREATION_CONFIRM = 'neo4j_relationship_creation_confirm' + +NEO4J_MAX_CONN_LIFE_TIME_SEC = Neo4jCsvPublisherConfigs.NEO4J_MAX_CONN_LIFE_TIME_SEC + +# list of nodes that are create only, and not updated if match exists +NEO4J_CREATE_ONLY_NODES = Neo4jCsvPublisherConfigs.NEO4J_CREATE_ONLY_NODES + +# list of node labels that could attempt to be accessed simultaneously +NEO4J_DEADLOCK_NODE_LABELS = 'neo4j_deadlock_node_labels' + +NEO4J_USER = Neo4jCsvPublisherConfigs.NEO4J_USER +NEO4J_PASSWORD = Neo4jCsvPublisherConfigs.NEO4J_PASSWORD +# in Neo4j (v4.0+), we can create and use more than one active database at the same time +NEO4J_DATABASE_NAME = Neo4jCsvPublisherConfigs.NEO4J_DATABASE_NAME + +# NEO4J_ENCRYPTED is a boolean indicating whether to use SSL/TLS when connecting +NEO4J_ENCRYPTED = Neo4jCsvPublisherConfigs.NEO4J_ENCRYPTED +# NEO4J_VALIDATE_SSL is a boolean indicating whether to validate the server's SSL/TLS +# cert against system CAs +NEO4J_VALIDATE_SSL = Neo4jCsvPublisherConfigs.NEO4J_VALIDATE_SSL + +# This will be used to provide unique tag to the node and relationship +JOB_PUBLISH_TAG = PublisherConfigs.JOB_PUBLISH_TAG + +# any additional fields that should be added to nodes and rels through config +ADDITIONAL_FIELDS = PublisherConfigs.ADDITIONAL_PUBLISHER_METADATA_FIELDS + +# Neo4j property name for published tag +PUBLISHED_TAG_PROPERTY_NAME = PublisherConfigs.PUBLISHED_TAG_PROPERTY_NAME + +# Neo4j property name for last updated timestamp +LAST_UPDATED_EPOCH_MS = PublisherConfigs.LAST_UPDATED_EPOCH_MS + +# A boolean flag to indicate if publisher_metadata (e.g. 
published_tag, +# publisher_last_updated_epoch_ms) +# will be included as properties of the Neo4j nodes +ADD_PUBLISHER_METADATA = PublishBehaviorConfigs.ADD_PUBLISHER_METADATA + +RELATION_PREPROCESSOR = 'relation_preprocessor' + +# CSV HEADER +# A header with this suffix will be pass to Neo4j statement without quote +UNQUOTED_SUFFIX = ':UNQUOTED' +# A header for Node label +NODE_LABEL_KEY = 'LABEL' +# A header for Node key +NODE_KEY_KEY = 'KEY' +# Required columns for Node +NODE_REQUIRED_KEYS = {NODE_LABEL_KEY, NODE_KEY_KEY} + +# Relationship relates two nodes together +# Start node label +RELATION_START_LABEL = 'START_LABEL' +# Start node key +RELATION_START_KEY = 'START_KEY' +# End node label +RELATION_END_LABEL = 'END_LABEL' +# Node node key +RELATION_END_KEY = 'END_KEY' +# Type for relationship (Start Node)->(End Node) +RELATION_TYPE = 'TYPE' +# Type for reverse relationship (End Node)->(Start Node) +RELATION_REVERSE_TYPE = 'REVERSE_TYPE' +# Required columns for Relationship +RELATION_REQUIRED_KEYS = {RELATION_START_LABEL, RELATION_START_KEY, + RELATION_END_LABEL, RELATION_END_KEY, + RELATION_TYPE, RELATION_REVERSE_TYPE} + +DEFAULT_CONFIG = ConfigFactory.from_dict({NEO4J_TRANSACTION_SIZE: 500, + NEO4J_PROGRESS_REPORT_FREQUENCY: 500, + NEO4J_RELATIONSHIP_CREATION_CONFIRM: False, + NEO4J_MAX_CONN_LIFE_TIME_SEC: 50, + NEO4J_DATABASE_NAME: neo4j.DEFAULT_DATABASE, + ADDITIONAL_FIELDS: {}, + ADD_PUBLISHER_METADATA: True, + RELATION_PREPROCESSOR: NoopRelationPreprocessor()}) + +# transient error retries and sleep time +RETRIES_NUMBER = 5 +SLEEP_TIME = 2 + +LOGGER = logging.getLogger(__name__) + + +class Neo4jCsvPublisher(Publisher): + """ + A Publisher takes two folders for input and publishes to Neo4j. + One folder will contain CSV file(s) for Node where the other folder will contain CSV + file(s) for Relationship. 
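To make the header constants defined above concrete, an illustrative node CSV row and relation CSV row could look like the following (values mirror the Cypher examples shown in the publish methods below; a header suffix like ':UNQUOTED' marks a value that is passed to the statement without quoting):

    KEY,LABEL,name,order_pos:UNQUOTED,type
    presto://gold.test_schema1/test_table1/test_id1,Column,test_id1,2,bigint

    START_LABEL,START_KEY,END_LABEL,END_KEY,TYPE,REVERSE_TYPE
    Table,presto://gold.test_schema1/test_table1,Column,presto://gold.test_schema1/test_table1/test_col1,COLUMN,BELONG_TO_TABLE
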
+ + Neo4j follows Label Node properties Graph and more information about this is in: + https://neo4j.com/docs/developer-manual/current/introduction/graphdb-concepts/ + + #TODO User UNWIND batch operation for better performance + """ + + def __init__(self) -> None: + super(Neo4jCsvPublisher, self).__init__() + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(DEFAULT_CONFIG) + + self._count: int = 0 + self._progress_report_frequency = conf.get_int(NEO4J_PROGRESS_REPORT_FREQUENCY) + self._node_files = self._list_files(conf, NODE_FILES_DIR) + self._node_files_iter = iter(self._node_files) + + self._relation_files = self._list_files(conf, RELATION_FILES_DIR) + self._relation_files_iter = iter(self._relation_files) + + uri = conf.get_string(NEO4J_END_POINT_KEY) + driver_args = { + 'uri': uri, + 'max_connection_lifetime': conf.get_int(NEO4J_MAX_CONN_LIFE_TIME_SEC), + 'auth': (conf.get_string(NEO4J_USER), conf.get_string(NEO4J_PASSWORD)), + } + + # if URI scheme not secure set `trust`` and `encrypted` to default values + # https://neo4j.com/docs/api/python-driver/current/api.html#uri + _, security_type, _ = parse_neo4j_uri(uri=uri) + if security_type not in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE]: + default_security_conf = {'trust': neo4j.TRUST_ALL_CERTIFICATES, 'encrypted': True} + driver_args.update(default_security_conf) + + # if NEO4J_VALIDATE_SSL or NEO4J_ENCRYPTED are set in config pass them to the driver + validate_ssl_conf = conf.get(NEO4J_VALIDATE_SSL, None) + encrypted_conf = conf.get(NEO4J_ENCRYPTED, None) + if validate_ssl_conf is not None: + driver_args['trust'] = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if validate_ssl_conf \ + else neo4j.TRUST_ALL_CERTIFICATES + if encrypted_conf is not None: + driver_args['encrypted'] = encrypted_conf + + self._driver = GraphDatabase.driver(**driver_args) + + self._db_name = conf.get_string(NEO4J_DATABASE_NAME) + self._session = self._driver.session(database=self._db_name) + + self._transaction_size = conf.get_int(NEO4J_TRANSACTION_SIZE) + self._confirm_rel_created = conf.get_bool(NEO4J_RELATIONSHIP_CREATION_CONFIRM) + + # config is list of node label. 
+ # When set, this list specifies a list of nodes that shouldn't be updated, if exists + self.create_only_nodes = set(conf.get_list(NEO4J_CREATE_ONLY_NODES, default=[])) + self.deadlock_node_labels = set(conf.get_list(NEO4J_DEADLOCK_NODE_LABELS, default=[])) + self.labels: Set[str] = set() + self.publish_tag: str = conf.get_string(JOB_PUBLISH_TAG) + self.additional_fields: Dict = conf.get(ADDITIONAL_FIELDS) + self.add_publisher_metadata: bool = conf.get_bool(ADD_PUBLISHER_METADATA) + if self.add_publisher_metadata and not self.publish_tag: + raise Exception(f'{JOB_PUBLISH_TAG} should not be empty') + + self._relation_preprocessor = conf.get(RELATION_PREPROCESSOR) + + LOGGER.info('Publishing Node csv files %s, and Relation CSV files %s', + self._node_files, + self._relation_files) + + def _list_files(self, conf: ConfigTree, path_key: str) -> List[str]: + """ + List files from directory + :param conf: + :param path_key: + :return: List of file paths + """ + if path_key not in conf: + return [] + + path = conf.get_string(path_key) + return [join(path, f) for f in listdir(path) if isfile(join(path, f))] + + def publish_impl(self) -> None: # noqa: C901 + """ + Publishes Nodes first and then Relations + :return: + """ + + start = time.time() + + LOGGER.info('Creating indices using Node files: %s', self._node_files) + for node_file in self._node_files: + self._create_indices(node_file=node_file) + + LOGGER.info('Publishing Node files: %s', self._node_files) + try: + tx = self._session.begin_transaction() + while True: + try: + node_file = next(self._node_files_iter) + tx = self._publish_node(node_file, tx=tx) + except StopIteration: + break + + LOGGER.info('Publishing Relationship files: %s', self._relation_files) + while True: + try: + relation_file = next(self._relation_files_iter) + tx = self._publish_relation(relation_file, tx=tx) + except StopIteration: + break + + tx.commit() + LOGGER.info('Committed total %i statements', self._count) + + # TODO: Add statsd support + LOGGER.info('Successfully published. Elapsed: %i seconds', time.time() - start) + except Exception as e: + LOGGER.exception('Failed to publish. Rolling back.') + if not tx.closed(): + tx.rollback() + raise e + + def get_scope(self) -> str: + return 'publisher.neo4j' + + def _create_indices(self, node_file: str) -> None: + """ + Go over the node file and try creating unique index + :param node_file: + :return: + """ + LOGGER.info('Creating indices. (Existing indices will be ignored)') + + with open(node_file, 'r', encoding='utf8') as node_csv: + for node_record in pandas.read_csv(node_csv, + na_filter=False).to_dict(orient='records'): + label = node_record[NODE_LABEL_KEY] + if label not in self.labels: + self._try_create_index(label) + self.labels.add(label) + + LOGGER.info('Indices have been created.') + + def _publish_node(self, node_file: str, tx: Transaction) -> Transaction: + """ + Iterate over the csv records of a file, each csv record transform to Merge statement + and will be executed. + All nodes should have a unique key, and this method will try to create unique index on + the LABEL when it sees first time within a job scope. 
+ Example of Cypher query executed by this method: + MERGE (col_test_id1:Column {key: 'presto://gold.test_schema1/test_table1/test_id1'}) + ON CREATE SET col_test_id1.name = 'test_id1', + col_test_id1.order_pos = 2, + col_test_id1.type = 'bigint' + ON MATCH SET col_test_id1.name = 'test_id1', + col_test_id1.order_pos = 2, + col_test_id1.type = 'bigint' + + :param node_file: + :return: + """ + + with open(node_file, 'r', encoding='utf8') as node_csv: + for node_record in pandas.read_csv(node_csv, + na_filter=False).to_dict(orient="records"): + stmt = self.create_node_merge_statement(node_record=node_record) + params = self._create_props_param(node_record) + tx = self._execute_statement(stmt, tx, params) + return tx + + def is_create_only_node(self, node_record: dict) -> bool: + """ + Check if node can be updated + :param node_record: + :return: + """ + if self.create_only_nodes: + return node_record[NODE_LABEL_KEY] in self.create_only_nodes + else: + return False + + def create_node_merge_statement(self, node_record: dict) -> str: + """ + Creates node merge statement + :param node_record: + :return: + """ + template = Template(""" + MERGE (node:{{ LABEL }} {key: $KEY}) + ON CREATE SET {{ PROP_BODY }} + {% if update %} ON MATCH SET {{ PROP_BODY }} {% endif %} + """) + + prop_body = self._create_props_body(node_record, NODE_REQUIRED_KEYS, 'node') + + return template.render(LABEL=node_record["LABEL"], + PROP_BODY=prop_body, + update=(not self.is_create_only_node(node_record))) + + def _publish_relation(self, relation_file: str, tx: Transaction) -> Transaction: + """ + Creates relation between two nodes. + (In Amundsen, all relation is bi-directional) + + Example of Cypher query executed by this method: + MATCH (n1:Table {key: 'presto://gold.test_schema1/test_table1'}), + (n2:Column {key: 'presto://gold.test_schema1/test_table1/test_col1'}) + MERGE (n1)-[r1:COLUMN]->(n2)-[r2:BELONG_TO_TABLE]->(n1) + RETURN n1.key, n2.key + + :param relation_file: + :return: + """ + + if self._relation_preprocessor.is_perform_preprocess(): + LOGGER.info('Pre-processing relation with %s', self._relation_preprocessor) + + count = 0 + with open(relation_file, 'r', encoding='utf8') as relation_csv: + for rel_record in pandas.read_csv(relation_csv, + na_filter=False).to_dict(orient="records"): + # TODO not sure if deadlock on badge node arises in preporcessing or not + stmt, params = self._relation_preprocessor.preprocess_cypher( + start_label=rel_record[RELATION_START_LABEL], + end_label=rel_record[RELATION_END_LABEL], + start_key=rel_record[RELATION_START_KEY], + end_key=rel_record[RELATION_END_KEY], + relation=rel_record[RELATION_TYPE], + reverse_relation=rel_record[RELATION_REVERSE_TYPE]) + + if stmt: + tx = self._execute_statement(stmt, tx=tx, params=params) + count += 1 + + LOGGER.info('Executed pre-processing Cypher statement %i times', count) + + with open(relation_file, 'r', encoding='utf8') as relation_csv: + for rel_record in pandas.read_csv(relation_csv, na_filter=False).to_dict(orient="records"): + exception_exists = True + retries_for_exception = RETRIES_NUMBER + while exception_exists and retries_for_exception > 0: + try: + stmt = self.create_relationship_merge_statement(rel_record=rel_record) + params = self._create_props_param(rel_record) + tx = self._execute_statement(stmt, tx, params, + expect_result=self._confirm_rel_created) + exception_exists = False + except TransientError as e: + if rel_record[RELATION_START_LABEL] in self.deadlock_node_labels \ + or rel_record[RELATION_END_LABEL] in 
self.deadlock_node_labels: + time.sleep(SLEEP_TIME) + retries_for_exception -= 1 + else: + raise e + + return tx + + def create_relationship_merge_statement(self, rel_record: dict) -> str: + """ + Creates relationship merge statement + :param rel_record: + :return: + """ + template = Template(""" + MATCH (n1:{{ START_LABEL }} {key: $START_KEY}), (n2:{{ END_LABEL }} {key: $END_KEY}) + MERGE (n1)-[r1:{{ TYPE }}]->(n2)-[r2:{{ REVERSE_TYPE }}]->(n1) + {% if update_prop_body %} + ON CREATE SET {{ prop_body }} + ON MATCH SET {{ prop_body }} + {% endif %} + RETURN n1.key, n2.key + """) + + prop_body_r1 = self._create_props_body(rel_record, RELATION_REQUIRED_KEYS, 'r1') + prop_body_r2 = self._create_props_body(rel_record, RELATION_REQUIRED_KEYS, 'r2') + prop_body = ' , '.join([prop_body_r1, prop_body_r2]) + + return template.render(START_LABEL=rel_record["START_LABEL"], + END_LABEL=rel_record["END_LABEL"], + TYPE=rel_record["TYPE"], + REVERSE_TYPE=rel_record["REVERSE_TYPE"], + update_prop_body=prop_body_r1, + prop_body=prop_body) + + def _create_props_param(self, record_dict: dict) -> dict: + params = {} + for k, v in record_dict.items(): + if k.endswith(UNQUOTED_SUFFIX): + k = k[:-len(UNQUOTED_SUFFIX)] + + params[k] = v + return params + + def _create_props_body(self, + record_dict: dict, + excludes: Set, + identifier: str) -> str: + """ + Creates properties body with params required for resolving template. + + e.g: Note that node.key3 is not quoted if header has UNQUOTED_SUFFIX. + identifier.key1 = 'val1' , identifier.key2 = 'val2', identifier.key3 = val3 + + :param record_dict: A dict represents CSV row + :param excludes: set of excluded columns that does not need to be in properties + (e.g: KEY, LABEL ...) + :param identifier: identifier that will be used in CYPHER query as shown on above example + :return: Properties body for Cypher statement + """ + props = [] + for k, v in record_dict.items(): + if k in excludes: + continue + + if k.endswith(UNQUOTED_SUFFIX): + k = k[:-len(UNQUOTED_SUFFIX)] + + props.append(f'{identifier}.{k} = ${k}') + + if self.add_publisher_metadata: + props.append(f"{identifier}.{PUBLISHED_TAG_PROPERTY_NAME} = '{self.publish_tag}'") + props.append(f"{identifier}.{LAST_UPDATED_EPOCH_MS} = timestamp()") + + # add additional metatada fields from config + for k, v in self.additional_fields.items(): + val = v if isinstance(v, int) or isinstance(v, float) else f"'{v}'" + props.append(f"{identifier}.{k}= {val}") + + return ', '.join(props) + + def _execute_statement(self, + stmt: str, + tx: Transaction, + params: Optional[dict] = None, + expect_result: bool = False) -> Transaction: + """ + Executes statement against Neo4j. If execution fails, it rollsback and raise exception. + If 'expect_result' flag is True, it confirms if result object is not null. + :param stmt: + :param tx: + :param count: + :param expect_result: By having this True, it will validate if result object is not None. 
+ :return: + """ + try: + LOGGER.debug('Executing statement: %s with params %s', stmt, params) + + result = tx.run(str(stmt), parameters=params) + if expect_result and not result.single(): + raise RuntimeError(f'Failed to executed statement: {stmt}') + + self._count += 1 + if self._count > 1 and self._count % self._transaction_size == 0: + tx.commit() + LOGGER.info(f'Committed {self._count} statements so far') + return self._session.begin_transaction() + + if self._count > 1 and self._count % self._progress_report_frequency == 0: + LOGGER.info(f'Processed {self._count} statements so far') + + return tx + except Exception as e: + LOGGER.exception('Failed to execute Cypher query') + if not tx.closed(): + tx.rollback() + raise e + + def _try_create_index(self, label: str) -> None: + """ + For any label seen first time for this publisher it will try to create unique index. + Neo4j ignores a second creation in 3.x, but raises an error in 4.x. + :param label: + :return: + """ + stmt = Template(""" + CREATE CONSTRAINT ON (node:{{ LABEL }}) ASSERT node.key IS UNIQUE + """).render(LABEL=label) + + LOGGER.info(f'Trying to create index for label {label} if not exist: {stmt}') + with self._driver.session(database=self._db_name) as session: + try: + session.run(stmt) + except Neo4jError as e: + if 'An equivalent constraint already exists' not in e.__str__(): + raise + # Else, swallow the exception, to make this function idempotent. diff --git a/databuilder/databuilder/publisher/neo4j_csv_unwind_publisher.py b/databuilder/databuilder/publisher/neo4j_csv_unwind_publisher.py new file mode 100644 index 0000000000..02769e22ec --- /dev/null +++ b/databuilder/databuilder/publisher/neo4j_csv_unwind_publisher.py @@ -0,0 +1,373 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import csv +import ctypes +import logging +import time +from io import open +from typing import ( + Dict, List, Set, +) + +import neo4j +import pandas +from jinja2 import Template +from neo4j import GraphDatabase, Neo4jDriver +from neo4j.api import ( + SECURITY_TYPE_SECURE, SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, parse_neo4j_uri, +) +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.models.graph_serializable import ( + NODE_KEY, NODE_LABEL, RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, + RELATION_START_LABEL, RELATION_TYPE, +) +from databuilder.publisher.base_publisher import Publisher +from databuilder.publisher.publisher_config_constants import ( + Neo4jCsvPublisherConfigs, PublishBehaviorConfigs, PublisherConfigs, +) +from databuilder.utils.publisher_utils import ( + chunkify_list, create_neo4j_node_key_constraint, create_props_param, execute_neo4j_statement, get_props_body_keys, + list_files, +) + +# Setting field_size_limit to solve the error below +# _csv.Error: field larger than field limit (131072) +# https://stackoverflow.com/a/54517228/5972935 +csv.field_size_limit(int(ctypes.c_ulong(-1).value // 2)) + +# Required columns for Node +NODE_REQUIRED_KEYS = {NODE_LABEL, NODE_KEY} +# Required columns for Relationship +RELATION_REQUIRED_KEYS = {RELATION_START_LABEL, RELATION_START_KEY, + RELATION_END_LABEL, RELATION_END_KEY, + RELATION_TYPE, RELATION_REVERSE_TYPE} + +DEFAULT_CONFIG = ConfigFactory.from_dict({Neo4jCsvPublisherConfigs.NEO4J_TRANSACTION_SIZE: 1000, + Neo4jCsvPublisherConfigs.NEO4J_MAX_CONN_LIFE_TIME_SEC: 50, + Neo4jCsvPublisherConfigs.NEO4J_DATABASE_NAME: neo4j.DEFAULT_DATABASE, + PublishBehaviorConfigs.ADD_PUBLISHER_METADATA: True, + PublishBehaviorConfigs.PUBLISH_REVERSE_RELATIONSHIPS: True, + PublishBehaviorConfigs.PRESERVE_ADHOC_UI_DATA: True, + PublishBehaviorConfigs.PRESERVE_EMPTY_PROPS: True}) + +LOGGER = logging.getLogger(__name__) + + +class Neo4jCsvUnwindPublisher(Publisher): + """ + This publisher takes two folders for input and publishes to Neo4j. + One folder will contain CSV file(s) for Nodes where the other folder will contain CSV + file(s) for Relationships. + + The merge statements make use of the UNWIND clause to allow for batched params to be applied to each + statement. This improves performance by reducing the amount of individual transactions to the database, + and by allowing Neo4j to compile and cache the statement. + """ + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(DEFAULT_CONFIG) + + self._count: int = 0 + self._node_files = list_files(conf, PublisherConfigs.NODE_FILES_DIR) + self._node_files_iter = iter(self._node_files) + + self._relation_files = list_files(conf, PublisherConfigs.RELATION_FILES_DIR) + self._relation_files_iter = iter(self._relation_files) + + self._driver = self._driver_init(conf) + self._db_name = conf.get_string(Neo4jCsvPublisherConfigs.NEO4J_DATABASE_NAME) + self._transaction_size = conf.get_int(Neo4jCsvPublisherConfigs.NEO4J_TRANSACTION_SIZE) + + # config is list of node label. 
+ # When set, this list specifies a list of nodes that shouldn't be updated, if exists + self._create_only_nodes = set(conf.get_list(Neo4jCsvPublisherConfigs.NEO4J_CREATE_ONLY_NODES, default=[])) + self._labels: Set[str] = set() + self._publish_tag: str = conf.get_string(PublisherConfigs.JOB_PUBLISH_TAG) + self._additional_publisher_metadata_fields: Dict =\ + dict(conf.get(PublisherConfigs.ADDITIONAL_PUBLISHER_METADATA_FIELDS, default={})) + self._add_publisher_metadata: bool = conf.get_bool(PublishBehaviorConfigs.ADD_PUBLISHER_METADATA) + self._publish_reverse_relationships: bool = conf.get_bool(PublishBehaviorConfigs.PUBLISH_REVERSE_RELATIONSHIPS) + self._preserve_adhoc_ui_data = conf.get_bool(PublishBehaviorConfigs.PRESERVE_ADHOC_UI_DATA) + self._preserve_empty_props: bool = conf.get_bool(PublishBehaviorConfigs.PRESERVE_EMPTY_PROPS) + self._prop_types_to_configure: Dict =\ + dict(conf.get(Neo4jCsvPublisherConfigs.NEO4J_PROP_TYPES_TO_CONFIGURE, default={})) + if self._add_publisher_metadata and not self._publish_tag: + raise Exception(f'{PublisherConfigs.JOB_PUBLISH_TAG} should not be empty') + + LOGGER.info('Publishing Node csv files %s, and Relation CSV files %s', + self._node_files, + self._relation_files) + + def _driver_init(self, conf: ConfigTree) -> Neo4jDriver: + uri = conf.get_string(Neo4jCsvPublisherConfigs.NEO4J_END_POINT_KEY) + driver_args = { + 'uri': uri, + 'max_connection_lifetime': conf.get_int(Neo4jCsvPublisherConfigs.NEO4J_MAX_CONN_LIFE_TIME_SEC), + 'auth': (conf.get_string(Neo4jCsvPublisherConfigs.NEO4J_USER), + conf.get_string(Neo4jCsvPublisherConfigs.NEO4J_PASSWORD)), + } + + # if URI scheme not secure set `trust`` and `encrypted` to default values + # https://neo4j.com/docs/api/python-driver/current/api.html#uri + _, security_type, _ = parse_neo4j_uri(uri=uri) + if security_type not in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE]: + default_security_conf = {'trust': neo4j.TRUST_ALL_CERTIFICATES, 'encrypted': True} + driver_args.update(default_security_conf) + + # if NEO4J_VALIDATE_SSL or NEO4J_ENCRYPTED are set in config pass them to the driver + validate_ssl_conf = conf.get(Neo4jCsvPublisherConfigs.NEO4J_VALIDATE_SSL, None) + encrypted_conf = conf.get(Neo4jCsvPublisherConfigs.NEO4J_ENCRYPTED, None) + if validate_ssl_conf is not None: + driver_args['trust'] = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if validate_ssl_conf \ + else neo4j.TRUST_ALL_CERTIFICATES + if encrypted_conf is not None: + driver_args['encrypted'] = encrypted_conf + + driver = GraphDatabase.driver(**driver_args) + + try: + driver.verify_connectivity() + except Exception as e: + driver.close() + raise e + + return driver + + def publish_impl(self) -> None: # noqa: C901 + """ + Publishes Nodes first and then Relations + """ + start = time.time() + + for node_file in self._node_files: + self.pre_publish_node_file(node_file) + + LOGGER.info('Publishing Node files: %s', self._node_files) + while True: + try: + node_file = next(self._node_files_iter) + self._publish_node_file(node_file) + except StopIteration: + break + + for rel_file in self._relation_files: + self.pre_publish_rel_file(rel_file) + + LOGGER.info('Publishing Relationship files: %s', self._relation_files) + while True: + try: + relation_file = next(self._relation_files_iter) + self._publish_relation_file(relation_file) + except StopIteration: + break + + LOGGER.info('Committed total %i statements', self._count) + + # TODO: Add statsd support + LOGGER.info('Successfully published. 
Elapsed: %i seconds', time.time() - start) + + def get_scope(self) -> str: + return 'publisher.neo4j' + + # Can be overridden with custom action(s) + def pre_publish_node_file(self, node_file: str) -> None: + created_constraint_labels = create_neo4j_node_key_constraint(node_file, self._labels, + self._driver, self._db_name) + self._labels.union(created_constraint_labels) + + # Can be overridden with custom action(s) + def pre_publish_rel_file(self, rel_file: str) -> None: + pass + + def _publish_node_file(self, node_file: str) -> None: + with open(node_file, 'r', encoding='utf8') as node_csv: + csv_dataframe = pandas.read_csv(node_csv, na_filter=False) + all_node_records = csv_dataframe.to_dict(orient="records") + + # Get the first node label since they will be the same for all records in the file + merge_stmt = self._create_node_merge_statement(node_keys=csv_dataframe.columns.tolist(), + node_label=all_node_records[0][NODE_LABEL]) + + self._write_transactions(merge_stmt, all_node_records) + + def _create_node_merge_statement(self, node_keys: list, node_label: str) -> str: + template = Template(""" + UNWIND $batch AS row + MERGE (node:{{ LABEL }} {key: row.KEY}) + ON CREATE SET {{ PROPS_BODY_CREATE }} + {% if update %} ON MATCH SET {{ PROPS_BODY_UPDATE }} {% endif %} + """) + + props_body_create = self._create_props_body(get_props_body_keys(node_keys, + NODE_REQUIRED_KEYS, + self._additional_publisher_metadata_fields), 'node') + + props_body_update = props_body_create + if self._preserve_adhoc_ui_data: + props_body_update = self._create_props_body(get_props_body_keys(node_keys, + NODE_REQUIRED_KEYS, + self._additional_publisher_metadata_fields), + 'node', True) + + return template.render(LABEL=node_label, + PROPS_BODY_CREATE=props_body_create, + PROPS_BODY_UPDATE=props_body_update, + update=(node_label not in self._create_only_nodes)) + + def _publish_relation_file(self, relation_file: str) -> None: + with open(relation_file, 'r', encoding='utf8') as relation_csv: + csv_dataframe = pandas.read_csv(relation_csv, na_filter=False) + all_rel_records = csv_dataframe.to_dict(orient="records") + + # Get the first relation labels since they will be the same for all records in the file + merge_stmt = self._create_relationship_merge_statement( + rel_keys=csv_dataframe.columns.tolist(), + start_label=all_rel_records[0][RELATION_START_LABEL], + end_label=all_rel_records[0][RELATION_END_LABEL], + relation_type=all_rel_records[0][RELATION_TYPE], + relation_reverse_type=all_rel_records[0][RELATION_REVERSE_TYPE] + ) + + self._write_transactions(merge_stmt, all_rel_records) + + def _create_relationship_merge_statement(self, + rel_keys: list, + start_label: str, + end_label: str, + relation_type: str, + relation_reverse_type: str) -> str: + template = Template(""" + UNWIND $batch as row + MATCH (n1:{{ START_LABEL }} {key: row.START_KEY}), (n2:{{ END_LABEL }} {key: row.END_KEY}) + {% if publish_reverse_relationships %} + MERGE (n1)-[r1:{{ TYPE }}]->(n2)-[r2:{{ REVERSE_TYPE }}]->(n1) + {% elif not publish_reverse_relationships and has_key %} + MERGE (n1)-[r1:{{ TYPE }} {key: row.key}]->(n2) + {% else %} + MERGE (n1)-[r1:{{ TYPE }}]->(n2) + {% endif %} + {% if update_props_body %} + ON CREATE SET {{ props_body_create }} + ON MATCH SET {{ props_body_update }} + {% endif %} + RETURN n1.key, n2.key + """) + + props_body_template = Template("""{{ props_body_r1 }} , {{ props_body_r2 }}""") + + props_body_r1 = self._create_props_body(get_props_body_keys(rel_keys, + RELATION_REQUIRED_KEYS, + 
self._additional_publisher_metadata_fields), 'r1') + props_body_r2 = self._create_props_body(get_props_body_keys(rel_keys, + RELATION_REQUIRED_KEYS, + self._additional_publisher_metadata_fields), 'r2') + if self._publish_reverse_relationships: + props_body_create = props_body_template.render(props_body_r1=props_body_r1, props_body_r2=props_body_r2) + else: + props_body_create = props_body_r1 + + props_body_update = props_body_create + if self._preserve_adhoc_ui_data: + props_body_r1 = self._create_props_body(get_props_body_keys(rel_keys, + RELATION_REQUIRED_KEYS, + self._additional_publisher_metadata_fields), + 'r1', True) + props_body_r2 = self._create_props_body(get_props_body_keys(rel_keys, + RELATION_REQUIRED_KEYS, + self._additional_publisher_metadata_fields), + 'r2', True) + if self._publish_reverse_relationships: + props_body_update = props_body_template.render(props_body_r1=props_body_r1, props_body_r2=props_body_r2) + else: + props_body_update = props_body_r1 + + return template.render(START_LABEL=start_label, + END_LABEL=end_label, + publish_reverse_relationships=self._publish_reverse_relationships, + has_key='key' in rel_keys, + TYPE=relation_type, + REVERSE_TYPE=relation_reverse_type, + update_props_body=props_body_r1, + props_body_create=props_body_create, + props_body_update=props_body_update) + + def _create_props_body(self, + record_keys: Set, + identifier: str, + rename_id_to_preserve_ui_data: bool = False) -> str: + """ + Creates properties body with params required for resolving template. + + e.g: Note that node.key3 is not quoted if header has UNQUOTED_SUFFIX. + identifier.key1 = 'val1' , identifier.key2 = 'val2', identifier.key3 = val3 + + :param record_keys: a list of keys for a CSV row + :param identifier: identifier that will be used in CYPHER query as shown on above example + :param rename_id_to_preserve_ui_data: specifies whether to null out the identifier to prevent it from updating + :return: Properties body for Cypher statement + """ + # For SET, if the evaluated expression is null, no action is performed. I.e. `SET (null).foo = 5` is a noop. 
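        # As a rough illustration (the property name 'name' is arbitrary here), with
        # rename_id_to_preserve_ui_data=True the identifier below expands so a rendered
        # fragment looks roughly like:
        #   (CASE WHEN node.published_tag IS NOT NULL THEN node ELSE null END).name = row.name
        # i.e. nodes that carry no published_tag (created ad hoc through the UI) resolve to
        # null and are left untouched on the ON MATCH branch.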
+ # See https://neo4j.com/docs/cypher-manual/current/clauses/set/ + if rename_id_to_preserve_ui_data: + identifier = f""" + (CASE WHEN {identifier}.{PublisherConfigs.PUBLISHED_TAG_PROPERTY_NAME} IS NOT NULL + THEN {identifier} ELSE null END) + """ + + template = Template(""" + {% for k in record_keys %} + {% if preserve_empty_props %} + {% if k in prop_types_to_configure %} + {{ identifier }}.{{ k }} = {{ prop_types_to_configure[k] }}(row.{{ k }}) + {% else %} + {{ identifier }}.{{ k }} = row.{{ k }} + {% endif %} + {% else %} + {% if k in prop_types_to_configure %} + {{ identifier }}.{{ k }} = + (CASE row.{{ k }} WHEN '' THEN NULL ELSE {{ prop_types_to_configure[k] }}(row.{{ k }}) END) + {% else %} + {{ identifier }}.{{ k }} = (CASE row.{{ k }} WHEN '' THEN NULL ELSE row.{{ k }} END) + {% endif %} + {% endif %} + {{ ", " if not loop.last else "" }} + {% endfor %} + {% if record_keys and add_publisher_metadata %} + , + {% endif %} + {% if add_publisher_metadata %} + {% if published_tag_prop in prop_types_to_configure %} + {{ identifier }}.{{ published_tag_prop }} = + {{ prop_types_to_configure[published_tag_prop] }}('{{ publish_tag }}'), + {% else %} + {{ identifier }}.{{ published_tag_prop }} = '{{ publish_tag }}', + {% endif %} + {{ identifier }}.{{ last_updated_prop }} = timestamp() + {% endif %} + """) + + props_body = template.render(record_keys=record_keys, + preserve_empty_props=self._preserve_empty_props, + prop_types_to_configure=self._prop_types_to_configure, + identifier=identifier, + add_publisher_metadata=self._add_publisher_metadata, + published_tag_prop=PublisherConfigs.PUBLISHED_TAG_PROPERTY_NAME, + publish_tag=self._publish_tag, + last_updated_prop=PublisherConfigs.LAST_UPDATED_EPOCH_MS) + return props_body.strip() + + def _write_transactions(self, + stmt: str, + records: List[dict]) -> None: + for chunk in chunkify_list(records, self._transaction_size): + params_list = [] + for record in chunk: + params_list.append(create_props_param(record, self._additional_publisher_metadata_fields)) + + with self._driver.session(database=self._db_name) as session: + session.write_transaction(execute_neo4j_statement, stmt, {'batch': params_list}) + + self._count += len(params_list) + LOGGER.info(f'Committed {self._count} rows so far') diff --git a/databuilder/databuilder/publisher/neo4j_preprocessor.py b/databuilder/databuilder/publisher/neo4j_preprocessor.py new file mode 100644 index 0000000000..203d67311d --- /dev/null +++ b/databuilder/databuilder/publisher/neo4j_preprocessor.py @@ -0,0 +1,205 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import abc +import logging +import textwrap +from typing import ( + Dict, List, Optional, Tuple, +) + +LOGGER = logging.getLogger(__name__) + + +class RelationPreprocessor(object, metaclass=abc.ABCMeta): + """ + A Preprocessor for relations. Prior to publish Neo4j relations, RelationPreprocessor will be used for + pre-processing. + Neo4j Publisher will iterate through relation file and call preprocess_cypher to perform any pre-process requested. + + For example, if you need current job's relation data to be desired state, you can add delete statement in + pre-process_cypher method. With preprocess_cypher defined, and with long transaction size, Neo4j publisher will + atomically apply desired state. 
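    For illustration, a bare-bones custom implementation could look like the sketch below. The ColumnTagCleanupPreprocessor name and its Column/Tag label choice are made up for this example; DeleteRelationPreprocessor further down in this module is the ready-made variant of the same idea.

    from typing import Dict, Tuple

    from databuilder.publisher.neo4j_preprocessor import RelationPreprocessor


    class ColumnTagCleanupPreprocessor(RelationPreprocessor):
        """Hypothetical preprocessor: drop existing Column<->Tag relations before republishing."""

        def is_perform_preprocess(self) -> bool:
            return True

        def filter(self, start_label: str, end_label: str, start_key: str, end_key: str,
                   relation: str, reverse_relation: str) -> bool:
            # Only Column/Tag pairs need the cleanup; every other relation is skipped.
            return {start_label, end_label} == {'Column', 'Tag'}

        def preprocess_cypher_impl(self, start_label: str, end_label: str, start_key: str, end_key: str,
                                   relation: str, reverse_relation: str) -> Tuple[str, Dict[str, str]]:
            stmt = (f'MATCH (n1:{start_label} {{key: $start_key}})-[r]-(n2:{end_label} {{key: $end_key}}) '
                    'DELETE r')
            return stmt, {'start_key': start_key, 'end_key': end_key}
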
+ + + """ + + def preprocess_cypher(self, + start_label: str, + end_label: str, + start_key: str, + end_key: str, + relation: str, + reverse_relation: str) -> Optional[Tuple[str, Dict[str, str]]]: + """ + Provides a Cypher statement that will be executed before publishing relations. + :param start_label: + :param end_label: + :param start_key: + :param end_key: + :param relation: + :param reverse_relation: + :return: + """ + if self.filter(start_label=start_label, + end_label=end_label, + start_key=start_key, + end_key=end_key, + relation=relation, + reverse_relation=reverse_relation): + return self.preprocess_cypher_impl(start_label=start_label, + end_label=end_label, + start_key=start_key, + end_key=end_key, + relation=relation, + reverse_relation=reverse_relation) + return None + + @abc.abstractmethod + def preprocess_cypher_impl(self, + start_label: str, + end_label: str, + start_key: str, + end_key: str, + relation: str, + reverse_relation: str) -> Tuple[str, Dict[str, str]]: + """ + Provides a Cypher statement that will be executed before publishing relations. + :param start_label: + :param end_label: + :param relation: + :param reverse_relation: + :return: A Cypher statement + """ + pass + + def filter(self, + start_label: str, + end_label: str, + start_key: str, + end_key: str, + relation: str, + reverse_relation: str) -> bool: + """ + A method that filters pre-processing in record level. Returns True if it needs preprocessing, otherwise False. + :param start_label: + :param end_label: + :param start_key: + :param end_key: + :param relation: + :param reverse_relation: + :return: bool. True if it needs preprocessing, otherwise False. + """ + return True + + @abc.abstractmethod + def is_perform_preprocess(self) -> bool: + """ + A method for Neo4j Publisher to determine whether to perform pre-processing or not. Regard this method as a + global filter. + :return: True if you want to enable the pre-processing. + """ + pass + + +class NoopRelationPreprocessor(RelationPreprocessor): + + def preprocess_cypher_impl(self, + start_label: str, + end_label: str, + start_key: str, + end_key: str, + relation: str, + reverse_relation: str) -> Tuple[str, Dict[str, str]]: + return '', {} + + def is_perform_preprocess(self) -> bool: + return False + + +class DeleteRelationPreprocessor(RelationPreprocessor): + """ + A Relation Pre-processor that delete relationship before Neo4jPublisher publishes relations. + + Example use case: Take an example of an external privacy service trying to push personal identifiable + identification (PII) tag into Amundsen. It is fine to push set of PII tags for the first push, but it becomes a + challenge when it comes to following update as external service does not know current PII state in Amundsen. + + The easy solution is for external service to know desired state (certain columns should have certain PII tags), + and push that information. + Now the challenge is how Amundsen apply desired state. This is where DeleteRelationPreprocessor comes into the + picture. We can utilize DeleteRelationPreprocessor to let it delete certain relations in the job, + and let Neo4jPublisher update to desired state. Should there be a small window (between delete and update) that + Amundsen data is not complete, you can increase Neo4jPublisher's transaction size to make it atomic. However, + note that you should not set transaction size too big as Neo4j uses memory to store transaction and this use case + is proper for small size of batch job. 
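    As a usage sketch: the ('Table', 'Tag') pair below is only an example, and the config key assumes the default 'publisher.neo4j' scope with the publisher's RELATION_PREPROCESSOR option ('relation_preprocessor').

    from pyhocon import ConfigFactory

    from databuilder.publisher.neo4j_preprocessor import DeleteRelationPreprocessor

    job_config = ConfigFactory.from_dict({
        # delete existing Table<->Tag relations before Neo4jCsvPublisher republishes them
        'publisher.neo4j.relation_preprocessor': DeleteRelationPreprocessor(label_tuples=[('Table', 'Tag')]),
    })
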
+ """ + RELATION_MERGE_TEMPLATE = textwrap.dedent(""" + MATCH (n1:{start_label} {{key: $start_key }})-[r]-(n2:{end_label} {{key: $end_key }}) + {where_clause} + WITH r LIMIT 2 + DELETE r + RETURN count(*) as count; + """) + + def __init__(self, + label_tuples: Optional[List[Tuple[str, str]]] = None, + where_clause: str = '') -> None: + super(DeleteRelationPreprocessor, self).__init__() + self._label_tuples = set(label_tuples) if label_tuples else set() + + reversed_label_tuples = [(t2, t1) for t1, t2 in self._label_tuples] + self._label_tuples.update(reversed_label_tuples) + self._where_clause = where_clause + + def preprocess_cypher_impl(self, + start_label: str, + end_label: str, + start_key: str, + end_key: str, + relation: str, + reverse_relation: str) -> Tuple[str, Dict[str, str]]: + """ + Provides DELETE Relation Cypher query on specific relation. + :param start_label: + :param end_label: + :param start_key: + :param end_key: + :param relation: + :param reverse_relation: + :return: + """ + + if not (start_label or end_label or start_key or end_key): + raise Exception(f'all labels and keys are required: {locals()}') + + params = {'start_key': start_key, 'end_key': end_key} + return DeleteRelationPreprocessor.RELATION_MERGE_TEMPLATE.format(start_label=start_label, + end_label=end_label, + where_clause=self._where_clause), params + + def is_perform_preprocess(self) -> bool: + return True + + def filter(self, + start_label: str, + end_label: str, + start_key: str, + end_key: str, + relation: str, + reverse_relation: str) -> bool: + """ + If pair of labels is what client requested passed through label_tuples, filter will return True meaning that + it needs to be pre-processed. + :param start_label: + :param end_label: + :param start_key: + :param end_key: + :param relation: + :param reverse_relation: + :return: bool. True if it needs preprocessing, otherwise False. + """ + if self._label_tuples and (start_label, end_label) not in self._label_tuples: + return False + + return True diff --git a/databuilder/databuilder/publisher/neptune_csv_publisher.py b/databuilder/databuilder/publisher/neptune_csv_publisher.py new file mode 100644 index 0000000000..5ae60a3357 --- /dev/null +++ b/databuilder/databuilder/publisher/neptune_csv_publisher.py @@ -0,0 +1,176 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import datetime +import logging +import os +import time +from os import listdir +from os.path import isfile, join +from typing import List, Tuple + +from amundsen_gremlin.neptune_bulk_loader.api import NeptuneBulkLoaderApi, NeptuneBulkLoaderLoadStatusErrorLogEntry +from boto3.session import Session +from pyhocon import ConfigTree + +from databuilder.publisher.base_publisher import Publisher + +LOGGER = logging.getLogger(__name__) + + +class NeptuneCSVPublisher(Publisher): + """ + This Publisher takes two folders for input and publishes to Neptune. + One folder will contain CSV file(s) for Node where the other folder will contain CSV file(s) for Relationship. 
+ + This publisher uses the bulk api found in + https://github.com/amundsen-io/amundsengremlin/blob/master/amundsen_gremlin/neptune_bulk_loader/api.py + + which is a client for the the api found + https://docs.aws.amazon.com/neptune/latest/userguide/bulk-load.html + https://docs.aws.amazon.com/neptune/latest/userguide/bulk-load-tutorial-format-gremlin.html + """ + + # A directory that contains CSV files for nodes + NODE_FILES_DIR = 'node_files_directory' + # A directory that contains CSV files for relationships + RELATION_FILES_DIR = 'relation_files_directory' + + # --- AWS CONFIGURATION --- + # S3 bucket to upload files to + AWS_S3_BUCKET_NAME = 'bucket_name' + # S3 location where amundsen data can be exported to and Neptune can access + AWS_BASE_S3_DATA_PATH = 'base_amundsen_data_path' + + NEPTUNE_HOST = 'neptune_host' + + # AWS CONFIGURATION + AWS_REGION = 'aws_region' + AWS_ACCESS_KEY = 'aws_access_key' + AWS_SECRET_ACCESS_KEY = 'aws_secret_access_key' + AWS_SESSION_TOKEN = 'aws_session_token' + AWS_IAM_ROLE_NAME = 'aws_iam_role_name' + AWS_STS_ENDPOINT_URL = 'aws_sts_endpoint_url' + FAIL_ON_ERROR = "fail_on_error" + STATUS_POLLING_PERIOD = "status_polling_period" + + def __init__(self) -> None: + super(NeptuneCSVPublisher, self).__init__() + + def init(self, conf: ConfigTree) -> None: + self._boto_session = Session( + aws_access_key_id=conf.get_string(NeptuneCSVPublisher.AWS_ACCESS_KEY, default=None), + aws_secret_access_key=conf.get_string(NeptuneCSVPublisher.AWS_SECRET_ACCESS_KEY, default=None), + aws_session_token=conf.get_string(NeptuneCSVPublisher.AWS_SESSION_TOKEN, default=None), + region_name=conf.get_string(NeptuneCSVPublisher.AWS_REGION, default=None) + ) + + self.node_files_dir = conf.get_string(NeptuneCSVPublisher.NODE_FILES_DIR) + self.relation_files_dir = conf.get_string(NeptuneCSVPublisher.RELATION_FILES_DIR) + + self._neptune_host = conf.get_string(NeptuneCSVPublisher.NEPTUNE_HOST) + + neptune_bulk_endpoint_uri = "wss://{host}/gremlin".format( + host=self._neptune_host + ) + + self.bucket_name = conf.get_string(NeptuneCSVPublisher.AWS_S3_BUCKET_NAME) + + self.neptune_api_client = NeptuneBulkLoaderApi( + session=self._boto_session, + endpoint_uri=neptune_bulk_endpoint_uri, + s3_bucket_name=self.bucket_name, + iam_role_name=conf.get_string(NeptuneCSVPublisher.AWS_IAM_ROLE_NAME, default=None), + sts_endpoint=conf.get_string(NeptuneCSVPublisher.AWS_STS_ENDPOINT_URL, default=None), + ) + self.base_amundsen_data_path = conf.get_string(NeptuneCSVPublisher.AWS_BASE_S3_DATA_PATH) + self.fail_on_error = conf.get_bool(NeptuneCSVPublisher.FAIL_ON_ERROR, default=False) + self.status_polling_period = conf.get_int(NeptuneCSVPublisher.STATUS_POLLING_PERIOD, default=5) + + def publish_impl(self) -> None: + if not self._is_upload_required(): + return + + datetime_portion = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S") + s3_folder_location = "{base_directory}/{datetime_portion}".format( + base_directory=self.base_amundsen_data_path, + datetime_portion=datetime_portion, + ) + + self.upload_files(s3_folder_location) + + bulk_upload_response = self.neptune_api_client.load( + s3_object_key=s3_folder_location, + failOnError=self.fail_on_error + ) + + try: + load_id = bulk_upload_response['payload']['loadId'] + except KeyError: + raise Exception("Failed to load csv. 
Response: {0}".format(str(bulk_upload_response))) + + load_status = "LOAD_NOT_STARTED" + all_errors: List[NeptuneBulkLoaderLoadStatusErrorLogEntry] = [] + while load_status in ("LOAD_IN_PROGRESS", "LOAD_NOT_STARTED", "LOAD_IN_QUEUE"): + time.sleep(self.status_polling_period) + load_status, errors = self._poll_status(load_id) + all_errors.extend(errors) + + for error in all_errors: + exception_message = """ + Error Code: {error_code} + Error Message: {error_message} + Failed File: {s3_path} + """.format( + error_code=error.get('errorCode'), + error_message=error.get('errorMessage'), + s3_path=error.get('fileName') + ) + LOGGER.exception(exception_message) + + def _poll_status(self, load_id: str) -> Tuple[str, List[NeptuneBulkLoaderLoadStatusErrorLogEntry]]: + load_status_response = self.neptune_api_client.load_status( + load_id=load_id, + errors=True + ) + load_status_payload = load_status_response.get('payload', {}) + try: + load_status = load_status_payload['overallStatus']['status'] + except KeyError: + raise Exception("Failed to check status of {0} response: {1}".format( + str(load_id), + repr(load_status_response) + )) + return load_status, load_status_payload.get('errors', {}).get('errorLogs', []) + + def _get_file_paths(self) -> List[str]: + node_names = [ + join(self.node_files_dir, f) for f in listdir(self.node_files_dir) + if isfile(join(self.node_files_dir, f)) + ] + edge_names = [ + join(self.relation_files_dir, f) for f in listdir(self.relation_files_dir) + if isfile(join(self.relation_files_dir, f)) + ] + return node_names + edge_names + + def _is_upload_required(self) -> bool: + file_names = self._get_file_paths() + return len(file_names) > 0 + + def upload_files(self, s3_folder_location: str) -> None: + file_paths = self._get_file_paths() + for file_location in file_paths: + with open(file_location, 'rb') as file_csv: + file_name = os.path.basename(file_location) + s3_object_key = "{s3_folder_location}/{file_name}".format( + s3_folder_location=s3_folder_location, + file_name=file_name + ) + self.neptune_api_client.upload( + f=file_csv, + s3_object_key=s3_object_key + ) + + def get_scope(self) -> str: + return 'publisher.neptune_csv_publisher' diff --git a/databuilder/databuilder/publisher/publisher_config_constants.py b/databuilder/databuilder/publisher/publisher_config_constants.py new file mode 100644 index 0000000000..f354656f0e --- /dev/null +++ b/databuilder/databuilder/publisher/publisher_config_constants.py @@ -0,0 +1,73 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +class PublisherConfigs: + # A directory that contains CSV files for nodes + NODE_FILES_DIR = 'node_files_directory' + # A directory that contains CSV files for relationships + RELATION_FILES_DIR = 'relation_files_directory' + + # A CSV header with this suffix will be passed to the statement without quotes + UNQUOTED_SUFFIX = ':UNQUOTED' + + # This will be used to provide unique tag to the node and relationship + JOB_PUBLISH_TAG = 'job_publish_tag' + + # any additional fields that should be added to nodes and rels through config + ADDITIONAL_PUBLISHER_METADATA_FIELDS = 'additional_publisher_metadata_fields' + + # Property name for published tag + PUBLISHED_TAG_PROPERTY_NAME = 'published_tag' + # Property name for last updated timestamp + LAST_UPDATED_EPOCH_MS = 'publisher_last_updated_epoch_ms' + + +class PublishBehaviorConfigs: + # A boolean flag to indicate if publisher_metadata (e.g. 
published_tag, + # publisher_last_updated_epoch_ms) + # will be included as properties of the nodes + ADD_PUBLISHER_METADATA = 'add_publisher_metadata' + + # NOTE: Do not use this unless you have a specific use case for it. Amundsen expects two way relationships, and + # the default value should be set to true to publish relations in both directions. If it is overridden and set + # to false, reverse relationships will not be published. + PUBLISH_REVERSE_RELATIONSHIPS = 'publish_reverse_relationships' + + # If enabled, stops the publisher from updating a node or relationship + # created via the UI, e.g. a description or owner added manually by an Amundsen user. + # Such nodes/relationships will not have a 'published_tag' property that is set by databuilder. + PRESERVE_ADHOC_UI_DATA = 'preserve_adhoc_ui_data' + + # If enabled, the default behavior will continue to publish properties with empty values. + # If False, empty properties will be set to NULL and will not show up on the node or relation. + PRESERVE_EMPTY_PROPS = 'preserve_empty_props' + + +class Neo4jCsvPublisherConfigs: + # A end point for Neo4j e.g: bolt://localhost:9999 + NEO4J_END_POINT_KEY = 'neo4j_endpoint' + # A transaction size that determines how often it commits. + NEO4J_TRANSACTION_SIZE = 'neo4j_transaction_size' + + NEO4J_MAX_CONN_LIFE_TIME_SEC = 'neo4j_max_conn_life_time_sec' + + # list of nodes that are create only, and not updated if match exists + NEO4J_CREATE_ONLY_NODES = 'neo4j_create_only_nodes' + + NEO4J_USER = 'neo4j_user' + NEO4J_PASSWORD = 'neo4j_password' + # in Neo4j (v4.0+), we can create and use more than one active database at the same time + NEO4J_DATABASE_NAME = 'neo4j_database' + + # NEO4J_ENCRYPTED is a boolean indicating whether to use SSL/TLS when connecting + NEO4J_ENCRYPTED = 'neo4j_encrypted' + # NEO4J_VALIDATE_SSL is a boolean indicating whether to validate the server's SSL/TLS + # cert against system CAs + NEO4J_VALIDATE_SSL = 'neo4j_validate_ssl' + + # This should be a dict using property names as keys mapped to the function name used to configure a specific + # type for that property. The values of the properties should be in the correct format that the function accepts. + # Example: a config of {'start_time': 'datetime', 'publish_tag': 'date'} where the property values are in the + # format T
from pyhocon import ConfigFactory

from databuilder.job.job import DefaultJob
from databuilder.task.neo4j_staleness_removal_task import Neo4jStalenessRemovalTask

task = Neo4jStalenessRemovalTask()
+job_config_dict = {
+    'job.identifier': 'remove_stale_data_job',
+    'task.remove_stale_data.neo4j_endpoint': neo4j_endpoint,
+    'task.remove_stale_data.neo4j_user': neo4j_user,
+    'task.remove_stale_data.neo4j_password': neo4j_password,
+    'task.remove_stale_data.staleness_max_pct': 10,
+    'task.remove_stale_data.target_nodes': ['Table', 'Column'],
+    'task.remove_stale_data.job_publish_tag': '2020-03-31',
+    'task.remove_stale_data.retain_data_with_no_publisher_metadata': True
+}
+job_config = ConfigFactory.from_dict(job_config_dict)
+job = DefaultJob(conf=job_config, task=task)
+job.launch()
+
+ +

Using “published_tag” to remove stale data

+

Use published_tag to remove stale data when you can be certain that, once all ingestion has completed, any non-matching tag is stale. For example, suppose you use the current date (or the execution date in Airflow) as the published_tag, “2020-03-31”. Once Databuilder has ingested all tables and columns, every table node and column node should have published_tag set to “2020-03-31”. It is then safe to assume that table and column nodes whose published_tag is different, such as “2020-03-30” or “2020-02-10”, have been deleted from the source metadata. You can use Neo4jStalenessRemovalTask to delete that stale data.

+
task = Neo4jStalenessRemovalTask()
+job_config_dict = {
+    'job.identifier': 'remove_stale_data_job',
+    'task.remove_stale_data.neo4j_endpoint': neo4j_endpoint,
+    'task.remove_stale_data.neo4j_user': neo4j_user,
+    'task.remove_stale_data.neo4j_password': neo4j_password,
+    'task.remove_stale_data.staleness_max_pct': 10,
+    'task.remove_stale_data.target_nodes': ['Table', 'Column'],
+    'task.remove_stale_data.job_publish_tag': '2020-03-31'
+}
+job_config = ConfigFactory.from_dict(job_config_dict)
+job = DefaultJob(conf=job_config, task=task)
+job.launch()
+
+ +

Note that there is a protection mechanism, staleness_max_pct, that protects your data from being wiped out when something is clearly wrong. staleness_max_pct first measures the proportion of elements that would be deleted, and if that proportion exceeds the threshold for any type (10% in the configuration above), the deletion is not executed and the task aborts.
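The guard is essentially a per-type percentage check. The sketch below is illustrative only, not the actual Neo4jStalenessRemovalTask implementation, and assumes you already have the total and stale record counts for a given type:

def validate_staleness(total_count: int, stale_count: int, staleness_max_pct: int) -> None:
    # Abort the removal if the proportion of stale records exceeds the configured threshold.
    stale_pct = stale_count * 100 / total_count
    if stale_pct > staleness_max_pct:
        raise Exception(f'Aborting stale data removal: {stale_pct:.1f}% of records would be '
                        f'deleted, exceeding the allowed {staleness_max_pct}%')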

+

Using “publisher_last_updated_epoch_ms” to remove stale data

+

You can think of this approach as TTL-based eviction. It is particularly useful when there are multiple ingestion pipelines and you cannot be sure when all ingestion is done. In that case, you can still say that if a specific node or relation has not been published in the past 3 days, it is stale data.

+
task = Neo4jStalenessRemovalTask()
+job_config_dict = {
+    'job.identifier': 'remove_stale_data_job',
+    'task.remove_stale_data.neo4j_endpoint': neo4j_endpoint,
+    'task.remove_stale_data.neo4j_user': neo4j_user,
+    'task.remove_stale_data.neo4j_password': neo4j_password,
+    'task.remove_stale_data.staleness_max_pct': 10,
+    'task.remove_stale_data.target_relations': ['READ', 'READ_BY'],
+    'task.remove_stale_data.milliseconds_to_expire': 86400000 * 3
+}
+job_config = ConfigFactory.from_dict(job_config_dict)
+job = DefaultJob(conf=job_config, task=task)
+job.launch()
+
+ +

The above configuration deletes stale usage relations (READ, READ_BY) by removing any READ or READ_BY relation that has not been published in the past 3 days. If the number of elements to be removed exceeds 10% for any type, the task aborts without executing any deletion.
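Conceptually, the expiration check compares each element’s publisher_last_updated_epoch_ms against a cutoff derived from milliseconds_to_expire. The snippet below is only a sketch of that arithmetic under the configuration above, not the task’s actual query logic:

import time

def is_expired(last_updated_epoch_ms: int, milliseconds_to_expire: int) -> bool:
    # Stale if the element was last published before (now - milliseconds_to_expire).
    cutoff_ms = int(time.time() * 1000) - milliseconds_to_expire
    return last_updated_epoch_ms < cutoff_ms

# For example, with a 3-day TTL:
# is_expired(relation_last_updated_epoch_ms, 86400000 * 3)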

+

Using node and relation conditions to remove stale data

+

You may want to remove stale nodes and relations that meet certain conditions rather than all of a given type. To do this, specify the inputs as a list of TargetWithCondition objects, each defining a target type and a condition. With this form of input, only stale nodes or relations that match both the target type and the condition will be removed.

+

Node conditions can use the predefined variable target, which represents the node. Relation conditions can use the variables target, start_node, and end_node, where target represents the relation and start_node/end_node represent the nodes on either side of it. See below for some example conditions.

+
from databuilder.task.neo4j_staleness_removal_task import TargetWithCondition
+
+task = Neo4jStalenessRemovalTask()
+job_config_dict = {
+    'job.identifier': 'remove_stale_data_job',
+    'task.remove_stale_data.neo4j_endpoint': neo4j_endpoint,
+    'task.remove_stale_data.neo4j_user': neo4j_user,
+    'task.remove_stale_data.neo4j_password': neo4j_password,
+    'task.remove_stale_data.staleness_max_pct': 10,
+    'task.remove_stale_data.target_nodes': [TargetWithCondition('Table', '(target)-[:COLUMN]->(:Column)'),  # All Table nodes that have a directional COLUMN relation to a Column node
+                                            TargetWithCondition('Column', '(target)-[]-(:Table) AND target.name=\'column_name\'')],  # All Column nodes named 'column_name' that have some relation to a Table node
+    'task.remove_stale_data.target_relations': [TargetWithCondition('COLUMN', '(start_node:Table)-[target]->(end_node:Column)'),  # All COLUMN relations that connect from a Table node to a Column node
+                                                TargetWithCondition('COLUMN', '(start_node:Column)-[target]-(end_node)')],  # All COLUMN relations that connect any direction between a Column node and another node
+    'task.remove_stale_data.milliseconds_to_expire': 86400000 * 3
+}
+job_config = ConfigFactory.from_dict(job_config_dict)
+job = DefaultJob(conf=job_config, task=task)
+job.launch()
+
+ +

You can include multiple inputs of the same type with different conditions as seen in the target_relations list above. Attribute checks can also be added as shown in the target_nodes list.

+

Dry run

+

Deletion is always scary, so it is better to perform a dry run before putting this into action. You can use dry run mode to see what kind of Cypher queries would be executed.

+
task = Neo4jStalenessRemovalTask()
+job_config_dict = {
+    'job.identifier': 'remove_stale_data_job',
+    'task.remove_stale_data.neo4j_endpoint': neo4j_endpoint,
+    'task.remove_stale_data.neo4j_user': neo4j_user,
+    'task.remove_stale_data.neo4j_password': neo4j_password,
+    'task.remove_stale_data.staleness_max_pct': 10,
+    'task.remove_stale_data.target_relations': ['READ', 'READ_BY'],
+    'task.remove_stale_data.milliseconds_to_expire': 86400000 * 3,
+    'task.remove_stale_data.dry_run': True
+}
+job_config = ConfigFactory.from_dict(job_config_dict)
+job = DefaultJob(conf=job_config, task=task)
+job.launch()
+
+ + + + + + + + + + + + + + + + + + + + +
+
+
+ + + + + + + + + \ No newline at end of file diff --git a/databuilder/requirements-dev.txt b/databuilder/requirements-dev.txt new file mode 100644 index 0000000000..fbdaf8e2d7 --- /dev/null +++ b/databuilder/requirements-dev.txt @@ -0,0 +1,22 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +# Common dependencies for code quality control (testing, linting, static checks, etc.) --------------------------------- + +flake8>=3.9.2 +flake8-tidy-imports>=4.3.0 +isort[colors]~=5.8.0 +mock>=4.0.3 +mypy>=1.9.0 +pytest>=6.2.4 +pytest-cov>=2.12.0 +pytest-env>=0.6.2 +pytest-mock>=3.6.1 +typed-ast>=1.4.3 +pyspark==3.0.1 +types-mock>=5.1.0.3 +types-protobuf>=4.24.0.4 +types-python-dateutil>=2.8.19.14 +types-pytz>=2023.3.1.1 +types-requests<2.31.0.7 +types-setuptools>=69.0.0.0 diff --git a/databuilder/requirements.txt b/databuilder/requirements.txt new file mode 100644 index 0000000000..5d246a9164 --- /dev/null +++ b/databuilder/requirements.txt @@ -0,0 +1,30 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +elasticsearch>=6.2.0,<8.0 +elasticsearch-dsl==7.4.0 +neo4j-driver>=4.4.5,<5.0 +requests>=2.25.0,<3.0 + +freezegun>=1.1.0 +atomicwrites>=1.1.5 +more-itertools>=4.2.0 +pluggy>=0.6.0 +py>=1.10.0 +pyhocon>=0.3.42 +pyparsing>=2.2.0 +sqlalchemy>=1.3.6 +wheel>=0.31.1 +pytz>=2018.4 +statsd>=3.2.1 +retrying>=1.3.3 +unicodecsv>=0.14.1,<1.0 +httplib2>=0.18.0 +text-unidecode>=1.3 +Jinja2>=2.10.0,<4 +pandas>=0.21.0,<1.5.0 +responses>=0.10.6 +jsonref==0.2 + +amundsen-common>=0.16.0 +amundsen-rds==0.0.8 diff --git a/databuilder/setup.cfg b/databuilder/setup.cfg new file mode 100644 index 0000000000..ebd33b6021 --- /dev/null +++ b/databuilder/setup.cfg @@ -0,0 +1,64 @@ +[flake8] +format = pylint +exclude = + CVS, + .svc, + .bzr, + .hg, + .git, + __pycache__, + venv, + .venv, + build, + databuilder/sql_parser/usage/presto/antlr_generated +max-complexity = 10 +max-line-length = 120 +ignore = W504 + +[pep8] +max-line-length = 120 + +[tool:pytest] +addopts = + -rs + --cov=databuilder + --cov-fail-under=70 + --cov-report=term-missing:skip-covered + --cov-report=xml + --cov-report=html + -vvv + +[coverage:run] +branch = True +omit = */antlr_generated/* +concurrency=multiprocessing + +[coverage:xml] +output = build/coverage.xml + +[coverage:html] +directory = build/coverage_html + +[mypy] +python_version = 3.8 +disallow_untyped_defs = True +ignore_missing_imports = True +exclude = example + +[isort] +profile = django +line_length = 120 +force_grid_wrap = 3 +combine_star = true +combine_as_imports = true +remove_redundant_aliases = true +color_output = true +skip_glob = [] + +[semantic_release] +version_variable = "./setup.py:__version__" +upload_to_pypi = true +upload_to_release = true +commit_subject = New release for {version} +commit_message = Signed-off-by: github-actions +commit_author = github-actions diff --git a/databuilder/setup.py b/databuilder/setup.py new file mode 100644 index 0000000000..805380cb92 --- /dev/null +++ b/databuilder/setup.py @@ -0,0 +1,147 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import os + +from setuptools import find_packages, setup + +__version__ = '7.4.6' + +requirements_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), + 'requirements.txt') +with open(requirements_path, 'r') as requirements_file: + requirements = requirements_file.readlines() + +requirements_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), + 'requirements-dev.txt') +with open(requirements_path, 'r') as requirements_file: + requirements_dev = requirements_file.readlines() + +kafka = ['confluent-kafka==2.3.0'] + +cassandra = ['cassandra-driver==3.20.1'] + +glue = ['boto3==1.17.23'] + +snowflake = [ + 'snowflake-connector-python', + 'snowflake-sqlalchemy' +] + +athena = ['PyAthena[SQLAlchemy]>=1.0.0, <2.0.0'] + +# Python API client for google +# License: Apache Software License +# Upstream url: https://github.com/googleapis/google-api-python-client +bigquery = [ + 'google-api-python-client>=1.6.0, <2.0.0dev', + 'google-auth-httplib2>=0.0.1', + 'google-auth>=1.16.0, <3.0.0dev' +] + +jsonpath = ['jsonpath_rw==1.4.0'] + +db2 = [ + 'ibm_db>=3.0.1', + 'ibm-db-sa-py3>=0.3.1-1' +] + +dremio = [ + 'pyodbc==4.0.30' +] + +druid = [ + 'pydruid' +] + +spark = [ + 'pyspark == 3.0.1' +] + +neptune = [ + 'amundsen-gremlin>=0.0.9', + 'Flask==1.0.2', + 'gremlinpython==3.4.3', + 'requests-aws4auth==1.1.0', + 'typing-extensions==4.1.0', + 'overrides==2.5', + 'boto3==1.17.23' +] + +feast = [ + 'feast==0.17.0', + 'fastapi!=0.76.*', + 'protobuf<=3.20.1' +] + +atlas = [ + 'pyatlasclient>=1.1.2', + 'apache-atlas>=0.0.11' +] + +oracle = [ + 'cx_Oracle==8.2.1' +] + +rds = [ + 'sqlalchemy>=1.3.6', + 'mysqlclient>=1.3.6,<3' +] + +salesforce = [ + 'simple-salesforce>=1.11.2' +] + +teradata = [ + 'teradatasqlalchemy==17.0.0.0' +] + +schema_registry = [ + 'python-schema-registry-client==2.4.0' +] + +all_deps = requirements + requirements_dev + kafka + cassandra + glue + snowflake + athena + \ + bigquery + jsonpath + db2 + dremio + druid + spark + feast + neptune + rds \ + + atlas + salesforce + oracle + teradata + schema_registry + +setup( + name='amundsen-databuilder', + version=__version__, + description='Amundsen Data builder', + url='https://www.github.com/amundsen-io/amundsen/tree/main/databuilder', + maintainer='Amundsen TSC', + maintainer_email='amundsen-tsc@lists.lfai.foundation', + packages=find_packages(exclude=['tests*']), + include_package_data=True, + dependency_links=[], + install_requires=requirements, + python_requires='>=3.8', + extras_require={ + 'all': all_deps, + 'dev': requirements_dev, + 'kafka': kafka, # To use with Kafka source extractor + 'cassandra': cassandra, + 'glue': glue, + 'snowflake': snowflake, + 'athena': athena, + 'bigquery': bigquery, + 'jsonpath': jsonpath, + 'db2': db2, + 'dremio': dremio, + 'druid': druid, + 'neptune': neptune, + 'delta': spark, + 'feast': feast, + 'atlas': atlas, + 'rds': rds, + 'salesforce': salesforce, + 'oracle': oracle, + 'teradata': teradata, + 'schema_registry': schema_registry, + }, + classifiers=[ + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + ], +) diff --git a/databuilder/tests/__init__.py b/databuilder/tests/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/integration/test_chained_trainsformers_task.py b/databuilder/tests/integration/test_chained_trainsformers_task.py new file mode 100644 index 0000000000..d2a787fada --- /dev/null +++ b/databuilder/tests/integration/test_chained_trainsformers_task.py @@ -0,0 +1,156 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import ( + Any, Iterable, List, Optional, +) + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor.base_extractor import Extractor +from databuilder.job.job import DefaultJob +from databuilder.loader.base_loader import Loader +from databuilder.models.table_metadata import TableMetadata +from databuilder.models.table_owner import TableOwner +from databuilder.task.task import DefaultTask +from databuilder.transformer.base_transformer import ( + ChainedTransformer, NoopTransformer, Transformer, +) + +TEST_DATA = [ + TableMetadata( + database="db1", schema="schema1", name="table1", cluster="prod", description="" + ), + TableMetadata( + database="db2", schema="schema2", name="table2", cluster="prod", description="" + ), +] + +EXPECTED_OWNERS = [ + TableOwner( + db_name="db1", + cluster="prod", + schema="schema1", + table_name="table1", + owners=["foo", "bar"], + ), + TableOwner( + db_name="db2", + cluster="prod", + schema="schema2", + table_name="table2", + owners=["foo", "bar"], + ), +] + + +class TestChainedTransformerTask(unittest.TestCase): + def test_multi_yield_task(self) -> None: + """ Test that MultiYieldTask is able to unpack a transformer which yields multiple nodes """ + + result = _run_transformer(AddFakeOwnerTransformer()) + + expected = [TEST_DATA[0], EXPECTED_OWNERS[0], TEST_DATA[1], EXPECTED_OWNERS[1]] + + self.assertEqual(repr(result), repr(expected)) + + def test_multi_yield_chained_transformer(self) -> None: + """ + Test that MultiYieldChainedTransformer is able handle both: + - transformers which yield multiple nodes + - transformers which transform single nodes + """ + + transformer = ChainedTransformer( + [AddFakeOwnerTransformer(), NoopTransformer(), DuplicateTransformer()] + ) + + result = _run_transformer(transformer) + + expected = [ + TEST_DATA[0], + TEST_DATA[0], + EXPECTED_OWNERS[0], + EXPECTED_OWNERS[0], + TEST_DATA[1], + TEST_DATA[1], + EXPECTED_OWNERS[1], + EXPECTED_OWNERS[1], + ] + + self.assertEqual(repr(result), repr(expected)) + + +class AddFakeOwnerTransformer(Transformer): + """ A transformer which yields the input record, and also a TableOwner """ + + def init(self, conf: ConfigTree) -> None: + pass + + def get_scope(self) -> str: + return "transformer.fake_owner" + + def transform(self, record: Any) -> Iterable[Any]: + yield record + if isinstance(record, TableMetadata): + yield TableOwner( + db_name=record.database, + schema=record.schema, + table_name=record.name, + cluster=record.cluster, + owners=["foo", "bar"], + ) + + +class DuplicateTransformer(Transformer): + """ A transformer which yields the input record twice""" + + def init(self, conf: ConfigTree) -> None: + pass + + def get_scope(self) -> str: + return "transformer.duplicate" + + def transform(self, record: Any) -> Iterable[Any]: + yield record + yield record + + +class ListExtractor(Extractor): + """ An extractor which yields a list of records """ + + def init(self, conf: ConfigTree) -> None: + self.items = conf.get("items") + + def extract(self) -> Optional[Any]: + try: + return self.items.pop(0) + except 
IndexError: + return None + + def get_scope(self) -> str: + return "extractor.test" + + +class ListLoader(Loader): + """ A loader which appends all records to a list """ + + def init(self, conf: ConfigTree) -> None: + self.loaded: List[Any] = [] + + def load(self, record: Any) -> None: + self.loaded.append(record) + + +def _run_transformer(transformer: Transformer) -> List[Any]: + job_config = ConfigFactory.from_dict({"extractor.test.items": TEST_DATA}) + + loader = ListLoader() + task = DefaultTask( + extractor=ListExtractor(), transformer=transformer, loader=loader + ) + job = DefaultJob(conf=job_config, task=task) + + job.launch() + return loader.loaded diff --git a/databuilder/tests/unit/__init__.py b/databuilder/tests/unit/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/callback/__init__.py b/databuilder/tests/unit/callback/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/callback/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/callback/test_call_back.py b/databuilder/tests/unit/callback/test_call_back.py new file mode 100644 index 0000000000..ffa2619f64 --- /dev/null +++ b/databuilder/tests/unit/callback/test_call_back.py @@ -0,0 +1,57 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import List + +from mock import MagicMock + +from databuilder.callback.call_back import Callback, notify_callbacks + + +class TestCallBack(unittest.TestCase): + + def test_success_notify(self) -> None: + callback1 = MagicMock() + callback2 = MagicMock() + callbacks: List[Callback] = [callback1, callback2] + + notify_callbacks(callbacks, is_success=True) + + self.assertTrue(callback1.on_success.called) + self.assertTrue(not callback1.on_failure.called) + self.assertTrue(callback2.on_success.called) + self.assertTrue(not callback2.on_failure.called) + + def test_failure_notify(self) -> None: + callback1 = MagicMock() + callback2 = MagicMock() + callbacks: List[Callback] = [callback1, callback2] + + notify_callbacks(callbacks, is_success=False) + + self.assertTrue(not callback1.on_success.called) + self.assertTrue(callback1.on_failure.called) + self.assertTrue(not callback2.on_success.called) + self.assertTrue(callback2.on_failure.called) + + def test_notify_failure(self) -> None: + callback1 = MagicMock() + callback2 = MagicMock() + callback2.on_success.side_effect = Exception('Boom') + callback3 = MagicMock() + callbacks: List[Callback] = [callback1, callback2, callback3] + + try: + notify_callbacks(callbacks, is_success=True) + self.assertTrue(False) + except Exception: + self.assertTrue(True) + + self.assertTrue(callback1.on_success.called) + self.assertTrue(callback2.on_success.called) + self.assertTrue(callback3.on_success.called) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/__init__.py b/databuilder/tests/unit/extractor/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/extractor/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/extractor/dashboard/__init__.py b/databuilder/tests/unit/extractor/dashboard/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/extractor/dashboard/apache_superset/__init__.py b/databuilder/tests/unit/extractor/dashboard/apache_superset/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/apache_superset/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/extractor/dashboard/apache_superset/test_apache_superset_chart_extractor.py b/databuilder/tests/unit/extractor/dashboard/apache_superset/test_apache_superset_chart_extractor.py new file mode 100644 index 0000000000..595ca6dae0 --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/apache_superset/test_apache_superset_chart_extractor.py @@ -0,0 +1,96 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 +import unittest +from typing import Any + +from mock import MagicMock, Mock +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.dashboard.apache_superset.apache_superset_chart_extractor import ApacheSupersetChartExtractor +from databuilder.models.dashboard.dashboard_chart import DashboardChart +from databuilder.models.dashboard.dashboard_query import DashboardQuery + +dashboard_details_response = { + 'dashboards': [ + { + '__Dashboard__': { + 'slices': [ + { + '__Slice__': { + 'id': 1, + 'slice_name': 'chart_1', + 'viz_type': 'pie_chart', + 'chart_url': '/chart_1' + } + }, + { + '__Slice__': { + 'id': 2, + 'slice_name': 'chart_2', + 'viz_type': 'table', + 'chart_url': '/chart_2' + } + } + ] + } + } + ] +} + + +class TestApacheSupersetChartExtractor(unittest.TestCase): + def setUp(self) -> None: + config = ConfigFactory.from_dict({ + 'extractor.apache_superset.dashboard_group_id': '1', + 'extractor.apache_superset.dashboard_group_name': 'dashboard group', + 'extractor.apache_superset.dashboard_group_description': 'dashboard group description', + 'extractor.apache_superset.cluster': 'gold', + 'extractor.apache_superset.apache_superset_security_settings_dict': dict(username='admin', + password='admin', + provider='db') + }) + + self.config = config + + def _get_extractor(self) -> Any: + extractor = self._extractor_class() + extractor.authenticate = MagicMock() + extractor.init(Scoped.get_scoped_conf(conf=self.config, scope=extractor.get_scope())) + + return extractor + + @property + def _extractor_class(self) -> Any: + return ApacheSupersetChartExtractor + + def test_extractor(self) -> None: + extractor = self._get_extractor() + + extractor.execute_query = Mock(side_effect=[{'ids': [1]}, {'ids': []}, dashboard_details_response]) + + record = extractor.extract() + + self.assertIsInstance(record, DashboardQuery) + self.assertEqual(record._query_name, 'default') + self.assertEqual(record._query_id, '-1') + self.assertEqual(record._product, 'superset') + self.assertEqual(record._cluster, 'gold') + + record = extractor.extract() + + self.assertIsInstance(record, DashboardChart) + self.assertEqual(record._query_id, '-1') + self.assertEqual(record._chart_id, '1') + 
self.assertEqual(record._chart_name, 'chart_1') + self.assertEqual(record._chart_type, 'pie_chart') + self.assertEqual(record._chart_url, '') + + record = extractor.extract() + + self.assertIsInstance(record, DashboardChart) + self.assertEqual(record._query_id, '-1') + self.assertEqual(record._chart_id, '2') + self.assertEqual(record._chart_name, 'chart_2') + self.assertEqual(record._chart_type, 'table') + self.assertEqual(record._chart_url, '') diff --git a/databuilder/tests/unit/extractor/dashboard/apache_superset/test_apache_superset_metadata_extractor.py b/databuilder/tests/unit/extractor/dashboard/apache_superset/test_apache_superset_metadata_extractor.py new file mode 100644 index 0000000000..5ea178af09 --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/apache_superset/test_apache_superset_metadata_extractor.py @@ -0,0 +1,78 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 +import unittest +from typing import Any + +from mock import MagicMock, Mock +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.dashboard.apache_superset.apache_superset_metadata_extractor import ( + ApacheSupersetMetadataExtractor, +) +from databuilder.models.dashboard.dashboard_last_modified import DashboardLastModifiedTimestamp +from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata + +dashboard_data_response = { + 'result': { + 'id': 2, + 'changed_on': '2021-05-14 08:41:05.934134', + 'dashboard_title': 'dashboard name', + 'url': '/2', + 'published': 'true' + } +} + + +class TestApacheSupersetMetadataExtractor(unittest.TestCase): + def setUp(self) -> None: + config = ConfigFactory.from_dict({ + 'extractor.apache_superset.dashboard_group_id': '1', + 'extractor.apache_superset.dashboard_group_name': 'dashboard group', + 'extractor.apache_superset.dashboard_group_description': 'dashboard group description', + 'extractor.apache_superset.cluster': 'gold', + 'extractor.apache_superset.apache_superset_security_settings_dict': dict(username='admin', + password='admin', + provider='db') + }) + self.config = config + + def _get_extractor(self) -> Any: + extractor = self._extractor_class() + extractor.authenticate = MagicMock() + extractor.init(Scoped.get_scoped_conf(conf=self.config, scope=extractor.get_scope())) + + return extractor + + @property + def _extractor_class(self) -> Any: + return ApacheSupersetMetadataExtractor + + def test_extractor(self) -> None: + extractor = self._get_extractor() + + extractor.execute_query = Mock(side_effect=[{'ids': [2]}, {'ids': []}, dashboard_data_response]) + + record = extractor.extract() + + self.assertIsInstance(record, DashboardMetadata) + self.assertEqual(record.dashboard_group, 'dashboard group') + self.assertEqual(record.dashboard_name, 'dashboard name') + self.assertEqual(record.description, '') + self.assertEqual(record.cluster, 'gold') + self.assertEqual(record.product, 'superset') + self.assertEqual(record.dashboard_group_id, '1') + self.assertEqual(record.dashboard_id, '2') + self.assertEqual(record.dashboard_group_description, 'dashboard group description') + self.assertEqual(record.created_timestamp, 0) + self.assertEqual(record.dashboard_group_url, 'http://localhost:8088') + self.assertEqual(record.dashboard_url, 'http://localhost:8088/2') + + record = extractor.extract() + + self.assertIsInstance(record, DashboardLastModifiedTimestamp) + self.assertEqual(record._dashboard_group_id, '1') + self.assertEqual(record._dashboard_id, '2') + 
self.assertEqual(record._last_modified_timestamp, 1620981665) + self.assertEqual(record._product, 'superset') + self.assertEqual(record._cluster, 'gold') diff --git a/databuilder/tests/unit/extractor/dashboard/apache_superset/test_apache_superset_table_extractor.py b/databuilder/tests/unit/extractor/dashboard/apache_superset/test_apache_superset_table_extractor.py new file mode 100644 index 0000000000..406c50cfcc --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/apache_superset/test_apache_superset_table_extractor.py @@ -0,0 +1,104 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 +import unittest +from typing import Any + +from mock import MagicMock, Mock +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.dashboard.apache_superset.apache_superset_table_extractor import ApacheSupersetTableExtractor +from databuilder.models.dashboard.dashboard_table import DashboardTable + +dataset_data_response_1 = { + 'result': { + 'sql': None, + 'table_name': 'table_name', + 'database': { + 'id': 1 + } + } +} + +dataset_objects_data_response_1 = { + 'dashboards': { + 'result': [ + { + 'id': 2 + } + ] + } +} + +database_data_response_1 = { + 'result': { + 'sqlalchemy_uri': 'postgresql://localhost:5432/db_name' + } +} + +dataset_data_response_2 = { + 'result': { + 'sql': None, + 'table_name': 'table_name_2', + 'database': { + 'id': 3 + } + } +} + +dataset_objects_data_response_2 = { + 'dashboards': { + 'result': [ + { + 'id': 2 + } + ] + } +} + +database_data_response_2 = { + 'result': { + 'sqlalchemy_uri': 'postgresql://localhost:5432/db_name_2' + } +} + + +class TestApacheSupersetTableExtractor(unittest.TestCase): + def setUp(self) -> None: + config = ConfigFactory.from_dict({ + 'extractor.apache_superset.dashboard_group_id': '1', + 'extractor.apache_superset.dashboard_group_name': 'dashboard group', + 'extractor.apache_superset.dashboard_group_description': 'dashboard group description', + 'extractor.apache_superset.cluster': 'gold', + 'extractor.apache_superset.apache_superset_security_settings_dict': dict(username='admin', + password='admin', + provider='db') + }) + + self.config = config + + def _get_extractor(self) -> Any: + extractor = self._extractor_class() + extractor.authenticate = MagicMock() + extractor.init(Scoped.get_scoped_conf(conf=self.config, scope=extractor.get_scope())) + + return extractor + + @property + def _extractor_class(self) -> Any: + return ApacheSupersetTableExtractor + + def test_extractor(self) -> None: + extractor = self._get_extractor() + + extractor.execute_query = Mock(side_effect=[{'ids': [2, 3]}, {'ids': []}, + dataset_data_response_1, dataset_objects_data_response_1, + dataset_data_response_2, dataset_objects_data_response_2, + database_data_response_1, database_data_response_2]) + + record = extractor.extract() + + self.assertIsInstance(record, DashboardTable) + self.assertEquals(record._dashboard_id, '2') + self.assertSetEqual(record._table_ids, + {'postgres://gold.db_name/table_name', 'postgres://gold.db_name_2/table_name_2'}) diff --git a/databuilder/tests/unit/extractor/dashboard/mode_analytics/__init__.py b/databuilder/tests/unit/extractor/dashboard/mode_analytics/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/mode_analytics/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_charts_batch_extractor.py b/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_charts_batch_extractor.py new file mode 100644 index 0000000000..73bdbfaec7 --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_charts_batch_extractor.py @@ -0,0 +1,64 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from mock import patch +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_charts_batch_extractor import ( + ModeDashboardChartsBatchExtractor, +) +from databuilder.models.dashboard.dashboard_chart import DashboardChart + + +class TestModeDashboardChartsBatchExtractor(unittest.TestCase): + def setUp(self) -> None: + config = ConfigFactory.from_dict({ + 'extractor.mode_dashboard_chart_batch.organization': 'amundsen', + 'extractor.mode_dashboard_chart_batch.mode_bearer_token': 'amundsen_bearer_token', + }) + self.config = config + + def test_dashboard_chart_extractor_empty_record(self) -> None: + extractor = ModeDashboardChartsBatchExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.config, scope=extractor.get_scope())) + + with patch('databuilder.rest_api.rest_api_query.RestApiQuery._send_request') as mock_request: + mock_request.return_value.json.return_value = {'charts': []} + record = extractor.extract() + self.assertIsNone(record) + + def test_dashboard_chart_extractor_actual_record(self) -> None: + extractor = ModeDashboardChartsBatchExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.config, scope=extractor.get_scope())) + + with patch('databuilder.rest_api.rest_api_query.RestApiQuery._send_request') as mock_request: + mock_request.return_value.json.return_value = { + 'charts': [ + { + 'space_token': 'ggg', + 'report_token': 'ddd', + 'query_token': 'yyy', + 'token': 'xxx', + 'chart_title': 'some chart', + 'chart_type': 'bigNumber' + } + ] + } + + record = extractor.extract() + self.assertIsInstance(record, DashboardChart) + self.assertEqual(record._dashboard_group_id, 'ggg') + self.assertEqual(record._dashboard_id, 'ddd') + self.assertEqual(record._query_id, 'yyy') + self.assertEqual(record._chart_id, 'xxx') + self.assertEqual(record._chart_name, 'some chart') + self.assertEqual(record._chart_type, 'bigNumber') + self.assertEqual(record._product, 'mode') + self.assertEqual(record._cluster, 'gold') + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_executions_extractor.py b/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_executions_extractor.py new file mode 100644 index 0000000000..f026819b12 --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_executions_extractor.py @@ -0,0 +1,52 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from mock import patch +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_executions_extractor import ( + ModeDashboardExecutionsExtractor, +) +from databuilder.models.dashboard.dashboard_execution import DashboardExecution + + +class TestModeDashboardExecutionsExtractor(unittest.TestCase): + def setUp(self) -> None: + config = ConfigFactory.from_dict({ + 'extractor.mode_dashboard_execution.organization': 'amundsen', + 'extractor.mode_dashboard_execution.mode_bearer_token': 'amundsen_bearer_token', + }) + self.config = config + + def test_extractor_extract_record(self) -> None: + extractor = ModeDashboardExecutionsExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.config, scope=extractor.get_scope())) + + with patch('databuilder.rest_api.rest_api_query.RestApiQuery._send_request') as mock_request: + mock_request.return_value.json.return_value = { + 'reports': [ + { + 'space_token': 'ggg', + 'token': 'ddd', + 'last_run_at': '2021-02-05T21:20:09.019Z', + 'last_run_state': 'failed', + } + ] + } + + record = next(extractor.extract()) + self.assertIsInstance(record, DashboardExecution) + self.assertEqual(record._dashboard_group_id, 'ggg') + self.assertEqual(record._dashboard_id, 'ddd') + self.assertEqual(record._execution_timestamp, 1612560009) + self.assertEqual(record._execution_state, 'failed') + self.assertEqual(record._product, 'mode') + self.assertEqual(record._cluster, 'gold') + self.assertEqual(record._execution_id, '_last_execution') + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_extractor.py b/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_extractor.py new file mode 100644 index 0000000000..3fd9230314 --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_extractor.py @@ -0,0 +1,125 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from mock import patch +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_extractor import ModeDashboardExtractor +from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata + + +class TestModeDashboardExtractor(unittest.TestCase): + def setUp(self) -> None: + config = ConfigFactory.from_dict({ + 'extractor.mode_dashboard.organization': 'amundsen', + 'extractor.mode_dashboard.mode_bearer_token': 'amundsen_bearer_token', + 'extractor.mode_dashboard.dashboard_group_ids_to_skip': ['ggg_to_skip'], + }) + self.config = config + + def test_extractor_extract_record(self) -> None: + extractor = ModeDashboardExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.config, scope=extractor.get_scope())) + + with patch('databuilder.rest_api.rest_api_query.RestApiQuery._send_request') as mock_request, \ + patch('databuilder.rest_api.query_merger.QueryMerger._compute_query_result') as mock_query_result: + mock_request.return_value.json.return_value = { + 'reports': [ + { + 'token': 'ddd', + 'name': 'dashboard name', + 'description': 'dashboard description', + 'created_at': '2021-02-05T21:20:09.019Z', + 'space_token': 'ggg', + } + ] + } + mock_query_result.return_value = { + 'ggg': { + 'dashboard_group_id': 'ggg', + 'dashboard_group': 'dashboard group name', + 'dashboard_group_description': 'dashboard group description', + } + } + + record = next(extractor.extract()) + self.assertIsInstance(record, DashboardMetadata) + self.assertEqual(record.dashboard_group, 'dashboard group name') + self.assertEqual(record.dashboard_name, 'dashboard name') + self.assertEqual(record.description, 'dashboard description') + self.assertEqual(record.cluster, 'gold') + self.assertEqual(record.product, 'mode') + self.assertEqual(record.dashboard_group_id, 'ggg') + self.assertEqual(record.dashboard_id, 'ddd') + self.assertEqual(record.dashboard_group_description, 'dashboard group description') + self.assertEqual(record.created_timestamp, 1612560009) + self.assertEqual(record.dashboard_group_url, 'https://app.mode.com/amundsen/spaces/ggg') + self.assertEqual(record.dashboard_url, 'https://app.mode.com/amundsen/reports/ddd') + + def test_extractor_skip_record(self) -> None: + extractor = ModeDashboardExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.config, scope=extractor.get_scope())) + + with patch('databuilder.rest_api.rest_api_query.RestApiQuery._send_request') as mock_request, \ + patch('databuilder.rest_api.query_merger.QueryMerger._compute_query_result') as mock_query_result: + mock_request.return_value.json.return_value = { + 'reports': [ + { + 'token': 'ddd', + 'name': 'dashboard name', + 'description': 'dashboard description', + 'created_at': '2021-02-05T21:20:09.019Z', + 'space_token': 'ggg', + }, + { + 'token': 'ddd_2', + 'name': 'dashboard name 2', + 'description': 'dashboard description 2', + 'created_at': '2021-02-05T21:20:09.019Z', + 'space_token': 'ggg_to_skip', + }, + { + 'token': 'ddd_3', + 'name': 'dashboard name 3', + 'description': 'dashboard description 3', + 'created_at': '2021-02-05T21:20:09.019Z', + 'space_token': 'ggg_not_skip', + }, + ] + } + mock_query_result.return_value = { + 'ggg': { + 'dashboard_group_id': 'ggg', + 'dashboard_group': 'dashboard group name', + 'dashboard_group_description': 'dashboard group description', + }, + 'ggg_to_skip': { + 'dashboard_group_id': 'ggg_to_skip', + 'dashboard_group': 'dashboard 
group name to skip', + 'dashboard_group_description': 'dashboard group description to skip', + }, + 'ggg_not_skip': { + 'dashboard_group_id': 'ggg_not_skip', + 'dashboard_group': 'dashboard group name not skip', + 'dashboard_group_description': 'dashboard group description not skip', + } + } + + record = next(extractor.extract()) + self.assertIsInstance(record, DashboardMetadata) + self.assertEqual(record.dashboard_group_id, 'ggg') + self.assertEqual(record.dashboard_id, 'ddd') + + record = next(extractor.extract()) + self.assertIsInstance(record, DashboardMetadata) + self.assertEqual(record.dashboard_group_id, 'ggg_not_skip') + self.assertEqual(record.dashboard_id, 'ddd_3') + + self.assertIsNone(extractor.extract()) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_last_modified_timestamp_extractor.py b/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_last_modified_timestamp_extractor.py new file mode 100644 index 0000000000..a0574bb622 --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_last_modified_timestamp_extractor.py @@ -0,0 +1,49 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from mock import patch +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_last_modified_timestamp_extractor import ( + ModeDashboardLastModifiedTimestampExtractor, +) +from databuilder.models.dashboard.dashboard_last_modified import DashboardLastModifiedTimestamp + + +class TestModeDashboardLastModifiedTimestampExtractor(unittest.TestCase): + def setUp(self) -> None: + config = ConfigFactory.from_dict({ + 'extractor.mode_dashboard_last_modified_timestamp_execution.organization': 'amundsen', + 'extractor.mode_dashboard_last_modified_timestamp_execution.mode_bearer_token': 'amundsen_bearer_token', + }) + self.config = config + + def test_extractor_extract_record(self) -> None: + extractor = ModeDashboardLastModifiedTimestampExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.config, scope=extractor.get_scope())) + + with patch('databuilder.rest_api.rest_api_query.RestApiQuery._send_request') as mock_request: + mock_request.return_value.json.return_value = { + 'reports': [ + { + 'space_token': 'ggg', + 'token': 'ddd', + 'edited_at': '2021-02-05T21:20:09.019Z', + } + ] + } + + record = next(extractor.extract()) + self.assertIsInstance(record, DashboardLastModifiedTimestamp) + self.assertEqual(record._dashboard_group_id, 'ggg') + self.assertEqual(record._dashboard_id, 'ddd') + self.assertEqual(record._last_modified_timestamp, 1612560009) + self.assertEqual(record._product, 'mode') + self.assertEqual(record._cluster, 'gold') + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_last_successful_executions_extractor.py b/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_last_successful_executions_extractor.py new file mode 100644 index 0000000000..fa1cfc21d2 --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_last_successful_executions_extractor.py @@ -0,0 +1,51 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from mock import patch +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_last_successful_executions_extractor import ( + ModeDashboardLastSuccessfulExecutionExtractor, +) +from databuilder.models.dashboard.dashboard_execution import DashboardExecution + + +class TestModeDashboardLastSuccessfulExecutionExtractor(unittest.TestCase): + def setUp(self) -> None: + config = ConfigFactory.from_dict({ + 'extractor.mode_dashboard_last_successful_execution.organization': 'amundsen', + 'extractor.mode_dashboard_last_successful_execution.mode_bearer_token': 'amundsen_bearer_token', + }) + self.config = config + + def test_extractor_extract_record(self) -> None: + extractor = ModeDashboardLastSuccessfulExecutionExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.config, scope=extractor.get_scope())) + + with patch('databuilder.rest_api.rest_api_query.RestApiQuery._send_request') as mock_request: + mock_request.return_value.json.return_value = { + 'reports': [ + { + 'space_token': 'ggg', + 'token': 'ddd', + 'last_successfully_run_at': '2021-02-05T21:20:09.019Z', + } + ] + } + + record = next(extractor.extract()) + self.assertIsInstance(record, DashboardExecution) + self.assertEqual(record._dashboard_group_id, 'ggg') + self.assertEqual(record._dashboard_id, 'ddd') + self.assertEqual(record._execution_timestamp, 1612560009) + self.assertEqual(record._execution_state, 'succeeded') + self.assertEqual(record._product, 'mode') + self.assertEqual(record._cluster, 'gold') + self.assertEqual(record._execution_id, '_last_successful_execution') + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_owner_extractor.py b/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_owner_extractor.py new file mode 100644 index 0000000000..fb00da5e07 --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_owner_extractor.py @@ -0,0 +1,46 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+from mock import patch
+from pyhocon import ConfigFactory
+
+from databuilder import Scoped
+from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_owner_extractor import ModeDashboardOwnerExtractor
+from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata
+from databuilder.models.dashboard.dashboard_owner import DashboardOwner
+
+
+class TestModeDashboardOwnerExtractor(unittest.TestCase):
+    def setUp(self) -> None:
+        config = ConfigFactory.from_dict({
+            'extractor.mode_dashboard_owner.organization': 'amundsen',
+            'extractor.mode_dashboard_owner.mode_bearer_token': 'amundsen_bearer_token',
+        })
+        self.config = config
+
+    def test_extractor_extract_record(self) -> None:
+        extractor = ModeDashboardOwnerExtractor()
+        extractor.init(Scoped.get_scoped_conf(conf=self.config, scope=extractor.get_scope()))
+
+        with patch('databuilder.rest_api.rest_api_query.RestApiQuery._send_request') as mock_request:
+            mock_request.return_value.json.return_value = {
+                'reports': [
+                    {
+                        'space_token': 'ggg',
+                        'token': 'ddd',
+                        'creator_email': 'amundsen@abc.com',
+                    }
+                ]
+            }
+
+            record = extractor.extract()
+            self.assertIsInstance(record, DashboardOwner)
+            self.assertEqual(record.owner_emails, ['amundsen@abc.com'])
+            self.assertEqual(record.start_label, DashboardMetadata.DASHBOARD_NODE_LABEL)
+            self.assertEqual(record.start_key, 'mode_dashboard://gold.ggg/ddd')
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_queries_extractor.py b/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_queries_extractor.py
new file mode 100644
index 0000000000..cf906ee5b7
--- /dev/null
+++ b/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_queries_extractor.py
@@ -0,0 +1,54 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+from mock import patch
+from pyhocon import ConfigFactory
+
+from databuilder import Scoped
+from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_queries_extractor import (
+    ModeDashboardQueriesExtractor,
+)
+from databuilder.models.dashboard.dashboard_query import DashboardQuery
+
+
+class TestModeDashboardQueriesExtractor(unittest.TestCase):
+    def setUp(self) -> None:
+        config = ConfigFactory.from_dict({
+            'extractor.mode_dashboard_query.organization': 'amundsen',
+            'extractor.mode_dashboard_query.mode_bearer_token': 'amundsen_bearer_token',
+        })
+        self.config = config
+
+    def test_extractor_extract_record(self) -> None:
+        extractor = ModeDashboardQueriesExtractor()
+        extractor.init(Scoped.get_scoped_conf(conf=self.config, scope=extractor.get_scope()))
+
+        with patch('databuilder.rest_api.rest_api_query.RestApiQuery._send_request') as mock_request:
+            mock_request.return_value.json.return_value = {
+                'queries': [
+                    {
+                        'space_token': 'ggg',
+                        'report_token': 'ddd',
+                        'token': 'qqq',
+                        'name': 'this query name',
+                        'raw_query': 'select 1',
+                    }
+                ]
+            }
+
+            record = next(extractor.extract())
+            self.assertIsInstance(record, DashboardQuery)
+            self.assertEqual(record._dashboard_group_id, 'ggg')
+            self.assertEqual(record._dashboard_id, 'ddd')
+            self.assertEqual(record._query_id, 'qqq')
+            self.assertEqual(record._query_name, 'this query name')
+            self.assertEqual(record._query_text, 'select 1')
+            self.assertEqual(record._url, 'https://app.mode.com/amundsen/reports/ddd/queries/qqq')
+            self.assertEqual(record._product, 'mode')
+            self.assertEqual(record._cluster, 'gold')
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_usage_extractor.py b/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_usage_extractor.py
new file mode 100644
index 0000000000..7990013006
--- /dev/null
+++ b/databuilder/tests/unit/extractor/dashboard/mode_analytics/test_mode_dashboard_usage_extractor.py
@@ -0,0 +1,66 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from mock import patch +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.dashboard.mode_analytics.mode_dashboard_usage_extractor import ModeDashboardUsageExtractor + + +class TestModeDashboardUsageExtractor(unittest.TestCase): + def setUp(self) -> None: + config = ConfigFactory.from_dict({ + 'extractor.mode_dashboard_usage.organization': 'amundsen', + 'extractor.mode_dashboard_usage.mode_bearer_token': 'amundsen_bearer_token', + }) + self.config = config + + def test_extractor_extract_record(self) -> None: + extractor = ModeDashboardUsageExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.config, scope=extractor.get_scope())) + + with patch('databuilder.rest_api.rest_api_query.RestApiQuery._send_request') as mock_request: + with patch('databuilder.rest_api.mode_analytics.mode_paginated_rest_api_query.ModePaginatedRestApiQuery._post_process'): # noqa + mock_request.return_value.json.side_effect = [ + { + 'report_stats': [ + { + 'report_token': 'ddd', + 'view_count': 20, + } + ] + }, + { + 'reports': [ + { + 'token': 'ddd', + 'space_token': 'ggg', + } + ] + }, + { + 'spaces': [ + { + 'token': 'ggg', + 'name': 'dashboard group name', + 'description': 'dashboard group description' + } + ] + }, + ] + + record = extractor.extract() + self.assertEqual(record['organization'], 'amundsen') + self.assertEqual(record['dashboard_id'], 'ddd') + self.assertEqual(record['accumulated_view_count'], 20) + self.assertEqual(record['dashboard_group_id'], 'ggg') + self.assertEqual(record['dashboard_group'], 'dashboard group name') + self.assertEqual(record['dashboard_group_description'], 'dashboard group description') + self.assertEqual(record['product'], 'mode') + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/dashboard/redash/__init__.py b/databuilder/tests/unit/extractor/dashboard/redash/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/redash/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/extractor/dashboard/redash/test_redash_dashboard_extractor.py b/databuilder/tests/unit/extractor/dashboard/redash/test_redash_dashboard_extractor.py new file mode 100644 index 0000000000..9ca70bf3cc --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/redash/test_redash_dashboard_extractor.py @@ -0,0 +1,284 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest +from typing import ( + Any, Dict, List, +) + +from mock import patch +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.dashboard.redash.redash_dashboard_extractor import ( + RedashDashboardExtractor, TableRelationData, +) +from databuilder.models.dashboard.dashboard_chart import DashboardChart +from databuilder.models.dashboard.dashboard_last_modified import DashboardLastModifiedTimestamp +from databuilder.models.dashboard.dashboard_owner import DashboardOwner +from databuilder.models.dashboard.dashboard_query import DashboardQuery +from databuilder.models.dashboard.dashboard_table import DashboardTable + +logging.basicConfig(level=logging.INFO) + + +def dummy_tables(*args: Any) -> List[TableRelationData]: + return [TableRelationData('some_db', 'prod', 'public', 'users')] + + +class MockApiResponse: + def __init__(self, data: Any) -> None: + self.json_data = data + self.status_code = 200 + + def json(self) -> Any: + return self.json_data + + def raise_for_status(self) -> None: + pass + + +class TestRedashDashboardExtractor(unittest.TestCase): + def test_table_relation_data(self) -> None: + tr = TableRelationData('db', 'cluster', 'schema', 'tbl') + self.assertEqual(tr.key, 'db://cluster.schema/tbl') + + def test_with_one_dashboard(self) -> None: + def mock_api_get(url: str, *args: Any, **kwargs: Any) -> MockApiResponse: + if '1000' in url: + return MockApiResponse({ + 'id': 1000, + 'widgets': [ + { + 'visualization': { + 'query': { + 'data_source_id': 1, + 'id': 1234, + 'name': 'Test Query', + 'query': 'SELECT id FROM users' + }, + 'id': 12345, + 'name': 'test_widget', + 'type': 'CHART', + }, + 'options': {} + } + ] + }) + + return MockApiResponse({ + 'page': 1, + 'count': 1, + 'page_size': 50, + 'results': [ + { + 'id': 1000, + 'name': 'Test Dash', + 'slug': 'test-dash', + 'created_at': '2020-01-01T00:00:00.000Z', + 'updated_at': '2020-01-02T00:00:00.000Z', + 'is_archived': False, + 'is_draft': False, + 'user': {'email': 'asdf@example.com'} + } + ] + }) + + redash_base_url = 'https://redash.example.com' + config = ConfigFactory.from_dict({ + 'extractor.redash_dashboard.redash_base_url': redash_base_url, + 'extractor.redash_dashboard.api_base_url': redash_base_url, # probably not but doesn't matter + 'extractor.redash_dashboard.api_key': 'abc123', + 'extractor.redash_dashboard.table_parser': + 'tests.unit.extractor.dashboard.redash.test_redash_dashboard_extractor.dummy_tables' + }) + + with patch('databuilder.rest_api.rest_api_query.requests.get') as mock_get: + mock_get.side_effect = mock_api_get + + extractor = RedashDashboardExtractor() + extractor.init(Scoped.get_scoped_conf(conf=config, scope=extractor.get_scope())) + + # DashboardMetadata + record = extractor.extract() + self.assertEqual(record.dashboard_id, '1000') + self.assertEqual(record.dashboard_name, 'Test Dash') + self.assertEqual(record.dashboard_group_id, RedashDashboardExtractor.DASHBOARD_GROUP_ID) + self.assertEqual(record.dashboard_group, RedashDashboardExtractor.DASHBOARD_GROUP_NAME) + self.assertEqual(record.product, RedashDashboardExtractor.PRODUCT) + self.assertEqual(record.cluster, RedashDashboardExtractor.DEFAULT_CLUSTER) + self.assertEqual(record.created_timestamp, 1577836800) + self.assertTrue(redash_base_url in record.dashboard_url) + self.assertTrue('1000' in record.dashboard_url) + + # DashboardLastModified + record = extractor.extract() + identity: Dict[str, Any] = { + 'dashboard_id': 
'1000',
+                'dashboard_group_id': RedashDashboardExtractor.DASHBOARD_GROUP_ID,
+                'product': RedashDashboardExtractor.PRODUCT,
+                'cluster': u'prod'
+            }
+            expected_timestamp = DashboardLastModifiedTimestamp(
+                last_modified_timestamp=1577923200,
+                **identity
+            )
+            self.assertEqual(record.__repr__(), expected_timestamp.__repr__())
+
+            # DashboardOwner
+            record = extractor.extract()
+            expected_owner = DashboardOwner(email='asdf@example.com', **identity)
+            self.assertEqual(record.__repr__(), expected_owner.__repr__())
+
+            # DashboardQuery
+            record = extractor.extract()
+            expected_query = DashboardQuery(
+                query_id='1234',
+                query_name='Test Query',
+                url=f'{redash_base_url}/queries/1234',
+                query_text='SELECT id FROM users',
+                **identity
+            )
+            self.assertEqual(record.__repr__(), expected_query.__repr__())
+
+            # DashboardChart
+            record = extractor.extract()
+            expected_chart = DashboardChart(
+                query_id='1234',
+                chart_id='12345',
+                chart_name='test_widget',
+                chart_type='CHART',
+                **identity
+            )
+            self.assertEqual(record.__repr__(), expected_chart.__repr__())
+
+            # DashboardTable
+            record = extractor.extract()
+            expected_table = DashboardTable(
+                table_ids=[TableRelationData('some_db', 'prod', 'public', 'users').key],
+                **identity
+            )
+            self.assertEqual(record.__repr__(), expected_table.__repr__())
+
+    def test_with_version_8(self) -> None:
+        def mock_api_get(url: str, *args: Any, **kwargs: Any) -> MockApiResponse:
+            if 'test-dash' in url:
+                return MockApiResponse({
+                    'id': 1000,
+                    'widgets': [
+                        {
+                            'visualization': {
+                                'query': {
+                                    'data_source_id': 1,
+                                    'id': 1234,
+                                    'name': 'Test Query',
+                                    'query': 'SELECT id FROM users'
+                                },
+                                'id': 12345,
+                                'name': 'test_widget',
+                                'type': 'CHART',
+                            },
+                            'options': {}
+                        }
+                    ]
+                })
+
+            return MockApiResponse({
+                'page': 1,
+                'count': 1,
+                'page_size': 50,
+                'results': [
+                    {
+                        'id': 1000,
+                        'name': 'Test Dash',
+                        'slug': 'test-dash',
+                        'created_at': '2020-01-01T00:00:00.000Z',
+                        'updated_at': '2020-01-02T00:00:00.000Z',
+                        'is_archived': False,
+                        'is_draft': False,
+                        'user': {'email': 'asdf@example.com'}
+                    }
+                ]
+            })
+
+        redash_base_url = 'https://redash.example.com'
+        config = ConfigFactory.from_dict({
+            'extractor.redash_dashboard.redash_base_url': redash_base_url,
+            'extractor.redash_dashboard.api_base_url': redash_base_url,  # probably not but doesn't matter
+            'extractor.redash_dashboard.api_key': 'abc123',
+            'extractor.redash_dashboard.table_parser':
+                'tests.unit.extractor.dashboard.redash.test_redash_dashboard_extractor.dummy_tables',
+            'extractor.redash_dashboard.redash_version': 8
+        })
+
+        with patch('databuilder.rest_api.rest_api_query.requests.get') as mock_get:
+            mock_get.side_effect = mock_api_get
+
+            extractor = RedashDashboardExtractor()
+            extractor.init(Scoped.get_scoped_conf(conf=config, scope=extractor.get_scope()))
+
+            # DashboardMetadata
+            record = extractor.extract()
+            self.assertEqual(record.dashboard_id, '1000')
+            self.assertEqual(record.dashboard_name, 'Test Dash')
+            self.assertEqual(record.dashboard_group_id, RedashDashboardExtractor.DASHBOARD_GROUP_ID)
+            self.assertEqual(record.dashboard_group, RedashDashboardExtractor.DASHBOARD_GROUP_NAME)
+            self.assertEqual(record.product, RedashDashboardExtractor.PRODUCT)
+            self.assertEqual(record.cluster, RedashDashboardExtractor.DEFAULT_CLUSTER)
+            self.assertEqual(record.created_timestamp, 1577836800)
+            self.assertTrue(redash_base_url in record.dashboard_url)
+            self.assertTrue('test-dash' in record.dashboard_url)
+
+            # DashboardLastModified
+            record = extractor.extract()
+            identity: Dict[str, Any]
= { + 'dashboard_id': '1000', + 'dashboard_group_id': RedashDashboardExtractor.DASHBOARD_GROUP_ID, + 'product': RedashDashboardExtractor.PRODUCT, + 'cluster': u'prod' + } + expected_timestamp = DashboardLastModifiedTimestamp( + last_modified_timestamp=1577923200, + **identity + ) + self.assertEqual(record.__repr__(), expected_timestamp.__repr__()) + + # DashboardOwner + record = extractor.extract() + expected_owner = DashboardOwner(email='asdf@example.com', **identity) + self.assertEqual(record.__repr__(), expected_owner.__repr__()) + + # DashboardQuery + record = extractor.extract() + expected_query = DashboardQuery( + query_id='1234', + query_name='Test Query', + url=f'{redash_base_url}/queries/1234', + query_text='SELECT id FROM users', + **identity + ) + self.assertEqual(record.__repr__(), expected_query.__repr__()) + + # DashboardChart + record = extractor.extract() + expected_chart = DashboardChart( + query_id='1234', + chart_id='12345', + chart_name='test_widget', + chart_type='CHART', + **identity + ) + self.assertEqual(record.__repr__(), expected_chart.__repr__()) + + # DashboardTable + record = extractor.extract() + expected_table = DashboardTable( + table_ids=[TableRelationData('some_db', 'prod', 'public', 'users').key], + **identity + ) + self.assertEqual(record.__repr__(), expected_table.__repr__()) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/dashboard/redash/test_redash_dashboard_utils.py b/databuilder/tests/unit/extractor/dashboard/redash/test_redash_dashboard_utils.py new file mode 100644 index 0000000000..4df500e577 --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/redash/test_redash_dashboard_utils.py @@ -0,0 +1,212 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import random +import unittest +from typing import ( + Any, Dict, List, +) + +from mock import patch + +from databuilder.extractor.dashboard.redash.redash_dashboard_utils import ( + RedashPaginatedRestApiQuery, generate_dashboard_description, get_auth_headers, get_text_widgets, + get_visualization_widgets, sort_widgets, +) +from databuilder.rest_api.base_rest_api_query import EmptyRestApiQuerySeed + +logging.basicConfig(level=logging.INFO) + + +class TestRedashDashboardUtils(unittest.TestCase): + def test_sort_widgets(self) -> None: + widgets = [ + { + 'text': 'a', + 'options': {} + }, + { + 'text': 'b', + 'options': {'position': {'row': 1, 'col': 1}} + }, + { + 'text': 'c', + 'options': {'position': {'row': 1, 'col': 2}} + }, + { + 'text': 'd', + 'options': {'position': {'row': 2, 'col': 1}} + } + ] + random.shuffle(widgets) + sorted_widgets = sort_widgets(widgets) + self.assertListEqual([widget['text'] for widget in sorted_widgets], ['a', 'b', 'c', 'd']) + + def test_widget_filters(self) -> None: + widgets: List[Dict[str, Any]] = [ + {'text': 'asdf', 'options': {'ex': 1}}, + {'text': 'asdf', 'options': {'ex': 2}}, + {'visualization': {}, 'options': {'ex': 1}}, + {'visualization': {}, 'options': {'ex': 2}}, + {'visualization': {}, 'options': {'ex': 3}} + ] + self.assertEqual(len(get_text_widgets(widgets)), 2) + self.assertEqual(len(get_visualization_widgets(widgets)), 3) + + def test_text_widget_props(self) -> None: + widget_data = { + 'text': 'asdf' + } + widget = get_text_widgets([widget_data])[0] + self.assertEqual(widget.text, 'asdf') + + def test_visualization_widget_props(self) -> None: + widget_data = { + 'visualization': { + 'query': { + 'id': 123, + 'data_source_id': 1, + 'query': 'SELECT 2+2 FROM DUAL', + 'name': 'Test' + }, + 'id': 12345, + 'name': 'test_widget', + 'type': 'CHART' + } + } + widget = get_visualization_widgets([widget_data])[0] + + self.assertEqual(widget.query_id, 123) + self.assertEqual(widget.data_source_id, 1) + self.assertEqual(widget.raw_query, 'SELECT 2+2 FROM DUAL') + self.assertEqual(widget.query_name, 'Test') + self.assertEqual(widget.visualization_id, 12345) + self.assertEqual(widget.visualization_name, 'test_widget') + self.assertEqual(widget.visualization_type, 'CHART') + + def test_descriptions_from_text(self) -> None: + text_widgets = get_text_widgets([ + {'text': 'T1'}, + {'text': 'T2'} + ]) + viz_widgets = get_visualization_widgets([ + { + 'visualization': { + 'query': { + 'id': 1, + 'data_source_id': 1, + 'name': 'Q1', + 'query': 'n/a' + } + } + }, + { + 'visualization': { + 'query': { + 'id': 2, + 'data_source_id': 1, + 'name': 'Q2', + 'query': 'n/a' + } + } + } + ]) + + # both text and viz widgets + desc1 = generate_dashboard_description(text_widgets, viz_widgets) + self.assertTrue('T1' in desc1) + self.assertTrue('T2' in desc1) + self.assertTrue('Q1' not in desc1) + + # only text widgets + desc2 = generate_dashboard_description(text_widgets, []) + self.assertEqual(desc1, desc2) + + # only viz widgets + desc3 = generate_dashboard_description([], viz_widgets) + self.assertTrue('Q1' in desc3) + self.assertTrue('Q2' in desc3) + + # no widgets + desc4 = generate_dashboard_description([], []) + self.assertTrue('empty' in desc4) + + def test_descriptions_remove_duplicate(self) -> None: + viz_widgets = get_visualization_widgets([ + { + 'visualization': { + 'query': { + 'id': 1, + 'data_source_id': 1, + 'name': 'same_query_name', + 'query': 'n/a' + } + } + }, + { + 'visualization': { + 'query': 
{ + 'id': 2, + 'data_source_id': 1, + 'name': 'same_query_name', + 'query': 'n/a' + } + } + } + ]) + desc1 = generate_dashboard_description([], viz_widgets) + self.assertEqual('A dashboard containing the following queries:\n\n- same_query_name', desc1) + + def test_auth_headers(self) -> None: + headers = get_auth_headers('testkey') + self.assertTrue('testkey' in headers['Authorization']) + + def test_paginated_rest_api_query(self) -> None: + paged_content = [ + { + 'page': 1, + 'page_size': 5, + 'count': 12, + 'results': [{'test': True}] * 5 + }, + { + 'page': 2, + 'page_size': 5, + 'count': 12, + 'results': [{'test': True}] * 5 + }, + { + 'page': 3, + 'page_size': 5, + 'count': 12, + 'results': [{'test': True}] * 2 + }, + { + 'page': 4, + 'page_size': 5, + 'count': 12, + 'results': [] + } + ] + + with patch('databuilder.rest_api.rest_api_query.requests.get') as mock_get: + # .json() is called twice (ugh), so we have to double each page + mock_get.return_value.json.side_effect = [page for page in paged_content for page in [page] * 2] + + q = RedashPaginatedRestApiQuery(query_to_join=EmptyRestApiQuerySeed(), + url='example.com', + json_path='results[*].[test]', + params={}, + field_names=['test'], + skip_no_result=True) + n_records = 0 + for record in q.execute(): + self.assertEqual(record['test'], True) + n_records += 1 + + self.assertEqual(n_records, 12) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/dashboard/tableau/__init__.py b/databuilder/tests/unit/extractor/dashboard/tableau/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/tableau/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/extractor/dashboard/tableau/test_tableau_dashboard_extractor.py b/databuilder/tests/unit/extractor/dashboard/tableau/test_tableau_dashboard_extractor.py new file mode 100644 index 0000000000..b689752229 --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/tableau/test_tableau_dashboard_extractor.py @@ -0,0 +1,126 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest +from typing import Any, Dict + +from mock import patch +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.dashboard.tableau.tableau_dashboard_extractor import TableauDashboardExtractor +from databuilder.extractor.dashboard.tableau.tableau_dashboard_utils import ( + TableauDashboardAuth, TableauGraphQLApiExtractor, +) + +logging.basicConfig(level=logging.INFO) + + +def mock_query(*_args: Any, **_kwargs: Any) -> Dict[str, Any]: + return { + 'workbooks': [ + { + 'id': 'fake-id', + 'name': 'Test Workbook', + 'createdAt': '2020-04-08T05:32:01Z', + 'description': '', + 'projectName': 'Test Project', + 'projectVizportalUrlId': 123, + 'vizportalUrlId': 456 + }, + { + 'id': 'fake-id', + 'name': None, + 'createdAt': '2020-04-08T05:32:01Z', + 'description': '', + 'projectName': None, + 'projectVizportalUrlId': 123, + 'vizportalUrlId': 456 + } + ] + } + + +def mock_token(*_args: Any, **_kwargs: Any) -> str: + return '123-abc' + + +class TestTableauDashboardExtractor(unittest.TestCase): + + @patch.object(TableauDashboardAuth, '_authenticate', mock_token) + @patch.object(TableauGraphQLApiExtractor, 'execute_query', mock_query) + def test_dashboard_metadata_extractor(self) -> None: + + config = ConfigFactory.from_dict({ + 'extractor.tableau_dashboard_metadata.api_base_url': 'https://api_base_url', + 'extractor.tableau_dashboard_metadata.tableau_base_url': 'https://tableau_base_url', + 'extractor.tableau_dashboard_metadata.api_version': 'tableau_api_version', + 'extractor.tableau_dashboard_metadata.site_name': 'tableau_site_name', + 'extractor.tableau_dashboard_metadata.tableau_personal_access_token_name': + 'tableau_personal_access_token_name', + 'extractor.tableau_dashboard_metadata.tableau_personal_access_token_secret': + 'tableau_personal_access_token_secret', + 'extractor.tableau_dashboard_metadata.excluded_projects': [], + 'extractor.tableau_dashboard_metadata.cluster': 'tableau_dashboard_cluster', + 'extractor.tableau_dashboard_metadata.database': 'tableau_dashboard_database', + 'extractor.tableau_dashboard_metadata.transformer.timestamp_str_to_epoch.timestamp_format': + '%Y-%m-%dT%H:%M:%SZ', + + }) + + extractor = TableauDashboardExtractor() + extractor.init(Scoped.get_scoped_conf(conf=config, scope=extractor.get_scope())) + + record = extractor.extract() + self.assertEqual(record.dashboard_id, 'Test Workbook') + self.assertEqual(record.dashboard_name, 'Test Workbook') + self.assertEqual(record.dashboard_group_id, 'Test Project') + self.assertEqual(record.dashboard_group, 'Test Project') + self.assertEqual(record.product, 'tableau') + self.assertEqual(record.cluster, 'tableau_dashboard_cluster') + self.assertEqual(record.dashboard_group_url, 'https://tableau_base_url/#/site/tableau_site_name/projects/123') + self.assertEqual(record.dashboard_url, 'https://tableau_base_url/#/site/tableau_site_name/workbooks/456/views') + self.assertEqual(record.created_timestamp, 1586323921) + + record = extractor.extract() + self.assertIsNone(record) + + # Test for Tableau single site deployment + config = ConfigFactory.from_dict({ + 'extractor.tableau_dashboard_metadata.api_base_url': 'https://api_base_url', + 'extractor.tableau_dashboard_metadata.tableau_base_url': 'https://tableau_base_url', + 'extractor.tableau_dashboard_metadata.api_version': 'tableau_api_version', + 'extractor.tableau_dashboard_metadata.site_name': '', + 
'extractor.tableau_dashboard_metadata.tableau_personal_access_token_name': + 'tableau_personal_access_token_name', + 'extractor.tableau_dashboard_metadata.tableau_personal_access_token_secret': + 'tableau_personal_access_token_secret', + 'extractor.tableau_dashboard_metadata.excluded_projects': [], + 'extractor.tableau_dashboard_metadata.cluster': 'tableau_dashboard_cluster', + 'extractor.tableau_dashboard_metadata.database': 'tableau_dashboard_database', + 'extractor.tableau_dashboard_metadata.transformer.timestamp_str_to_epoch.timestamp_format': + '%Y-%m-%dT%H:%M:%SZ', + + }) + + extractor = TableauDashboardExtractor() + extractor.init(Scoped.get_scoped_conf(conf=config, scope=extractor.get_scope())) + + record = extractor.extract() + self.assertEqual(record.dashboard_id, 'Test Workbook') + self.assertEqual(record.dashboard_name, 'Test Workbook') + self.assertEqual(record.dashboard_group_id, 'Test Project') + self.assertEqual(record.dashboard_group, 'Test Project') + self.assertEqual(record.product, 'tableau') + self.assertEqual(record.cluster, 'tableau_dashboard_cluster') + self.assertEqual(record.dashboard_group_url, 'https://tableau_base_url/#/projects/123') + self.assertEqual(record.dashboard_url, 'https://tableau_base_url/#/workbooks/456/views') + self.assertEqual(record.created_timestamp, 1586323921) + + record = extractor.extract() + self.assertIsNone(record) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/dashboard/tableau/test_tableau_dashboard_last_modified_extractor.py b/databuilder/tests/unit/extractor/dashboard/tableau/test_tableau_dashboard_last_modified_extractor.py new file mode 100644 index 0000000000..6a15794bfe --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/tableau/test_tableau_dashboard_last_modified_extractor.py @@ -0,0 +1,86 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest +from typing import Any, Dict + +from mock import patch +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.dashboard.tableau.tableau_dashboard_last_modified_extractor import ( + TableauDashboardLastModifiedExtractor, +) +from databuilder.extractor.dashboard.tableau.tableau_dashboard_utils import ( + TableauDashboardAuth, TableauGraphQLApiExtractor, +) + +logging.basicConfig(level=logging.INFO) + + +def mock_query(*_args: Any, **_kwargs: Any) -> Dict[str, Any]: + return { + 'workbooks': [ + { + 'id': 'fake-workbook-id', + 'name': 'Test Workbook', + 'projectName': 'Test Project', + 'updatedAt': '2020-08-04T20:16:05Z', + 'projectVizportalUrlId': 123, + 'vizportalUrlId': 456 + }, + { + 'id': 'fake-workbook-id', + 'name': None, + 'projectName': None, + 'createdAt': '2020-08-04T20:16:05Z', + 'projectVizportalUrlId': 123, + 'vizportalUrlId': 456 + } + ] + } + + +def mock_token(*_args: Any, **_kwargs: Any) -> str: + return '123-abc' + + +class TestTableauDashboardLastModified(unittest.TestCase): + + @patch.object(TableauDashboardAuth, '_authenticate', mock_token) + @patch.object(TableauGraphQLApiExtractor, 'execute_query', mock_query) + def test_dashboard_last_modified_extractor(self) -> None: + + config = ConfigFactory.from_dict({ + 'extractor.tableau_dashboard_last_modified.api_base_url': 'api_base_url', + 'extractor.tableau_dashboard_last_modified.api_version': 'tableau_api_version', + 'extractor.tableau_dashboard_last_modified.site_name': 'tableau_site_name', + 'extractor.tableau_dashboard_last_modified.tableau_personal_access_token_name': + 'tableau_personal_access_token_name', + 'extractor.tableau_dashboard_last_modified.tableau_personal_access_token_secret': + 'tableau_personal_access_token_secret', + 'extractor.tableau_dashboard_last_modified.excluded_projects': [], + 'extractor.tableau_dashboard_last_modified.cluster': 'tableau_dashboard_cluster', + 'extractor.tableau_dashboard_last_modified.database': 'tableau_dashboard_database', + 'extractor.tableau_dashboard_last_modified.transformer.timestamp_str_to_epoch.timestamp_format': + '%Y-%m-%dT%H:%M:%SZ', + + }) + + extractor = TableauDashboardLastModifiedExtractor() + extractor.init(Scoped.get_scoped_conf(conf=config, scope=extractor.get_scope())) + + record = extractor.extract() + self.assertEqual(record._dashboard_id, 'Test Workbook') + self.assertEqual(record._dashboard_group_id, 'Test Project') + self.assertEqual(record._product, 'tableau') + self.assertEqual(record._cluster, 'tableau_dashboard_cluster') + self.assertEqual(record._last_modified_timestamp, 1596572165) + + record = extractor.extract() + self.assertIsNone(record) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/dashboard/tableau/test_tableau_dashboard_query_extractor.py b/databuilder/tests/unit/extractor/dashboard/tableau/test_tableau_dashboard_query_extractor.py new file mode 100644 index 0000000000..34fb30ac83 --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/tableau/test_tableau_dashboard_query_extractor.py @@ -0,0 +1,75 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest +from typing import Any, Dict + +from mock import patch +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.dashboard.tableau.tableau_dashboard_query_extractor import TableauDashboardQueryExtractor +from databuilder.extractor.dashboard.tableau.tableau_dashboard_utils import ( + TableauDashboardAuth, TableauGraphQLApiExtractor, +) + +logging.basicConfig(level=logging.INFO) + + +def mock_query(*_args: Any, **_kwargs: Any) -> Dict[str, Any]: + return { + 'customSQLTables': [ + { + 'id': 'fake-query-id', + 'name': 'Test Query', + 'query': 'SELECT * FROM foo', + 'downstreamWorkbooks': [ + { + 'name': 'Test Workbook', + 'projectName': 'Test Project' + } + ] + } + ] + } + + +def mock_token(*_args: Any, **_kwargs: Any) -> str: + return '123-abc' + + +class TestTableauDashboardQuery(unittest.TestCase): + + @patch.object(TableauDashboardAuth, '_authenticate', mock_token) + @patch.object(TableauGraphQLApiExtractor, 'execute_query', mock_query) + def test_dashboard_query_extractor(self) -> None: + + config = ConfigFactory.from_dict({ + 'extractor.tableau_dashboard_query.api_base_url': 'api_base_url', + 'extractor.tableau_dashboard_query.api_version': 'tableau_api_version', + 'extractor.tableau_dashboard_query.site_name': 'tableau_site_name', + 'extractor.tableau_dashboard_query.tableau_personal_access_token_name': + 'tableau_personal_access_token_name', + 'extractor.tableau_dashboard_query.tableau_personal_access_token_secret': + 'tableau_personal_access_token_secret', + 'extractor.tableau_dashboard_query.excluded_projects': [], + 'extractor.tableau_dashboard_query.cluster': 'tableau_dashboard_cluster', + 'extractor.tableau_dashboard_query.database': 'tableau_dashboard_database', + 'extractor.tableau_dashboard_query.transformer.timestamp_str_to_epoch.timestamp_format': + '%Y-%m-%dT%H:%M:%SZ', + + }) + + extractor = TableauDashboardQueryExtractor() + extractor.init(Scoped.get_scoped_conf(conf=config, scope=extractor.get_scope())) + record = extractor.extract() + + self.assertEqual(record._query_name, 'Test Query') + self.assertEqual(record._query_text, 'SELECT * FROM foo') + self.assertEqual(record._dashboard_id, 'Test Workbook') + self.assertEqual(record._dashboard_group_id, 'Test Project') + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/dashboard/tableau/test_tableau_dashboard_table_extractor.py b/databuilder/tests/unit/extractor/dashboard/tableau/test_tableau_dashboard_table_extractor.py new file mode 100644 index 0000000000..d724091cca --- /dev/null +++ b/databuilder/tests/unit/extractor/dashboard/tableau/test_tableau_dashboard_table_extractor.py @@ -0,0 +1,120 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest +from typing import Any, Dict + +from mock import patch +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.dashboard.tableau.tableau_dashboard_table_extractor import TableauDashboardTableExtractor +from databuilder.extractor.dashboard.tableau.tableau_dashboard_utils import ( + TableauDashboardAuth, TableauGraphQLApiExtractor, +) + +logging.basicConfig(level=logging.INFO) + + +def mock_query(*_args: Any, **_kwargs: Any) -> Dict[str, Any]: + return { + 'workbooks': [ + { + 'name': 'Test Workbook', + 'projectName': 'Test Project', + 'upstreamTables': [ + { + 'name': 'test_table_1', + 'schema': 'test_schema_1', + 'database': { + 'name': 'test_database_1', + 'connectionType': 'redshift' + } + }, + { + 'name': 'test_table_2', + 'schema': 'test_schema_2', + 'database': { + 'name': 'test_database_2', + 'connectionType': 'redshift' + } + } + ] + }, + { + 'name': 'Test Workbook', + 'projectName': 'Test Project', + 'upstreamTables': [ + { + 'name': 'test_table_1', + 'schema': 'test_schema_1', + 'database': { + 'name': 'test_database_1', + 'connectionType': 'redshift' + } + }, + { + 'name': None, + 'schema': 'test_schema_2', + 'database': { + 'name': 'test_database_2', + 'connectionType': 'redshift' + } + } + ] + } + ] + } + + +def mock_token(*_args: Any, **_kwargs: Any) -> str: + return '123-abc' + + +class TestTableauDashboardTable(unittest.TestCase): + + @patch.object(TableauDashboardAuth, '_authenticate', mock_token) + @patch.object(TableauGraphQLApiExtractor, 'execute_query', mock_query) + def test_dashboard_table_extractor(self) -> None: + + config = ConfigFactory.from_dict({ + 'extractor.tableau_dashboard_table.api_base_url': 'api_base_url', + 'extractor.tableau_dashboard_table.api_version': 'tableau_api_version', + 'extractor.tableau_dashboard_table.site_name': 'tableau_site_name', + 'extractor.tableau_dashboard_table.tableau_personal_access_token_name': + 'tableau_personal_access_token_name', + 'extractor.tableau_dashboard_table.tableau_personal_access_token_secret': + 'tableau_personal_access_token_secret', + 'extractor.tableau_dashboard_table.excluded_projects': [], + 'extractor.tableau_dashboard_table.cluster': 'tableau_dashboard_cluster', + 'extractor.tableau_dashboard_table.database': 'tableau_dashboard_database', + 'extractor.tableau_dashboard_table.transformer.timestamp_str_to_epoch.timestamp_format': + '%Y-%m-%dT%H:%M:%SZ', + + }) + + extractor = TableauDashboardTableExtractor() + extractor.init(Scoped.get_scoped_conf(conf=config, scope=extractor.get_scope())) + record = extractor.extract() + + self.assertEqual(record._dashboard_id, 'Test Workbook') + self.assertEqual(record._dashboard_group_id, 'Test Project') + self.assertEqual(record._product, 'tableau') + self.assertEqual(record._cluster, 'tableau_dashboard_cluster') + self.assertEqual(record._table_ids, [ + 'tableau_dashboard_database://tableau_dashboard_cluster.test_schema_1/test_table_1', + 'tableau_dashboard_database://tableau_dashboard_cluster.test_schema_2/test_table_2']) + + record = extractor.extract() + + self.assertEqual(record._dashboard_id, 'Test Workbook') + self.assertEqual(record._dashboard_group_id, 'Test Project') + self.assertEqual(record._product, 'tableau') + self.assertEqual(record._cluster, 'tableau_dashboard_cluster') + self.assertEqual(record._table_ids, [ + 'tableau_dashboard_database://tableau_dashboard_cluster.test_schema_1/test_table_1']) + + +if __name__ == '__main__': + 
unittest.main() diff --git a/databuilder/tests/unit/extractor/restapi/__init__.py b/databuilder/tests/unit/extractor/restapi/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/extractor/restapi/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/extractor/restapi/test_rest_api_extractor.py b/databuilder/tests/unit/extractor/restapi/test_rest_api_extractor.py new file mode 100644 index 0000000000..c73231e8d7 --- /dev/null +++ b/databuilder/tests/unit/extractor/restapi/test_rest_api_extractor.py @@ -0,0 +1,51 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from pyhocon import ConfigFactory + +from databuilder.extractor.restapi.rest_api_extractor import ( + MODEL_CLASS, REST_API_QUERY, STATIC_RECORD_DICT, RestAPIExtractor, +) +from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata +from databuilder.rest_api.base_rest_api_query import RestApiQuerySeed + + +class TestRestAPIExtractor(unittest.TestCase): + + def test_static_data(self) -> None: + + conf = ConfigFactory.from_dict( + { + REST_API_QUERY: RestApiQuerySeed(seed_record=[{'foo': 'bar'}]), + STATIC_RECORD_DICT: {'john': 'doe'} + } + ) + extractor = RestAPIExtractor() + extractor.init(conf=conf) + + record = extractor.extract() + expected = {'foo': 'bar', 'john': 'doe'} + + self.assertDictEqual(expected, record) + + def test_model_construction(self) -> None: + conf = ConfigFactory.from_dict( + { + REST_API_QUERY: RestApiQuerySeed( + seed_record=[{'dashboard_group': 'foo', + 'dashboard_name': 'bar', + 'description': 'john', + 'dashboard_group_description': 'doe'}]), + MODEL_CLASS: 'databuilder.models.dashboard.dashboard_metadata.DashboardMetadata', + } + ) + extractor = RestAPIExtractor() + extractor.init(conf=conf) + + record = extractor.extract() + expected = DashboardMetadata(dashboard_group='foo', dashboard_name='bar', description='john', + dashboard_group_description='doe') + + self.assertEqual(expected.__repr__(), record.__repr__()) diff --git a/databuilder/tests/unit/extractor/test_athena_metadata_extractor.py b/databuilder/tests/unit/extractor/test_athena_metadata_extractor.py new file mode 100644 index 0000000000..653302c473 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_athena_metadata_extractor.py @@ -0,0 +1,260 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest +from typing import Any, Dict + +from mock import MagicMock, patch +from pyhocon import ConfigFactory + +from databuilder.extractor.athena_metadata_extractor import AthenaMetadataExtractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + + +class TestAthenaMetadataExtractor(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + config_dict = { + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + f'extractor.athena_metadata.{AthenaMetadataExtractor.CATALOG_KEY}': 'MY_CATALOG' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_extraction_with_empty_query_result(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = AthenaMetadataExtractor() + extractor.init(self.conf) + + results = extractor.extract() + self.assertEqual(results, None) + + def test_extraction_with_single_result(self) -> None: + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection: + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + table = { + 'schema': 'test_schema', + 'name': 'test_table', + 'description': '', + 'cluster': self.conf[f'extractor.athena_metadata.{AthenaMetadataExtractor.CATALOG_KEY}'], + } + + sql_execute.return_value = [ + self._union({ + 'col_name': 'col_id1', + 'col_type': 'bigint', + 'col_description': 'description of id1', + 'col_sort_order': 0, + 'extras': None + }, table), + self._union({ + 'col_name': 'col_id2', + 'col_type': 'bigint', + 'col_description': 'description of id2', + 'col_sort_order': 1, + 'extras': None + }, table), + self._union({ + 'col_name': 'is_active', + 'col_type': 'boolean', + 'col_description': None, + 'col_sort_order': 2, + 'extras': None + }, table), + self._union({ + 'col_name': 'source', + 'col_type': 'varchar', + 'col_description': 'description of source', + 'col_sort_order': 3, + 'extras': None + }, table), + self._union({ + 'col_name': 'etl_created_at', + 'col_type': 'timestamp', + 'col_description': None, + 'col_sort_order': 4, + 'extras': 'partition key' + }, table), + self._union({ + 'col_name': 'ds', + 'col_type': 'varchar', + 'col_description': None, + 'col_sort_order': 5, + 'extras': None + }, table) + ] + + extractor = AthenaMetadataExtractor() + extractor.init(self.conf) + actual = extractor.extract() + expected = TableMetadata('athena', + self.conf[f'extractor.athena_metadata.{AthenaMetadataExtractor.CATALOG_KEY}'], + 'test_schema', + 'test_table', '', + [ColumnMetadata('col_id1', 'description of id1', 'bigint', 0), + ColumnMetadata('col_id2', 'description of id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'partition key', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5)]) + self.assertEqual(expected.__repr__(), actual.__repr__()) + self.assertIsNone(extractor.extract()) + + def test_extraction_with_multiple_result(self) -> None: + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection: + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + table = {'schema': 
'test_schema1', + 'name': 'test_table1', + 'description': '', + 'cluster': self.conf[f'extractor.athena_metadata.{AthenaMetadataExtractor.CATALOG_KEY}'], + } + + table1 = {'schema': 'test_schema1', + 'name': 'test_table2', + 'description': '', + 'cluster': self.conf[f'extractor.athena_metadata.{AthenaMetadataExtractor.CATALOG_KEY}'], + } + + table2 = {'schema': 'test_schema2', + 'name': 'test_table3', + 'description': '', + 'cluster': self.conf[f'extractor.athena_metadata.{AthenaMetadataExtractor.CATALOG_KEY}'], + } + + sql_execute.return_value = [ + self._union( + {'col_name': 'col_id1', + 'col_type': 'bigint', + 'col_description': 'description of col_id1', + 'col_sort_order': 0, + 'extras': None}, table), + self._union( + {'col_name': 'col_id2', + 'col_type': 'bigint', + 'col_description': 'description of col_id2', + 'col_sort_order': 1, + 'extras': None}, table), + self._union( + {'col_name': 'is_active', + 'col_type': 'boolean', + 'col_description': None, + 'col_sort_order': 2, + 'extras': None}, table), + self._union( + {'col_name': 'source', + 'col_type': 'varchar', + 'col_description': 'description of source', + 'col_sort_order': 3, + 'extras': None}, table), + self._union( + {'col_name': 'etl_created_at', + 'col_type': 'timestamp', + 'col_description': '', + 'col_sort_order': 4, + 'extras': 'partition key'}, table), + self._union( + {'col_name': 'ds', + 'col_type': 'varchar', + 'col_description': None, + 'col_sort_order': 5, + 'extras': None}, table), + self._union( + {'col_name': 'col_name', + 'col_type': 'varchar', + 'col_description': 'description of col_name', + 'col_sort_order': 0, + 'extras': None}, table1), + self._union( + {'col_name': 'col_name2', + 'col_type': 'varchar', + 'col_description': 'description of col_name2', + 'col_sort_order': 1, + 'extras': None}, table1), + self._union( + {'col_name': 'col_id3', + 'col_type': 'varchar', + 'col_description': 'description of col_id3', + 'col_sort_order': 0, + 'extras': None}, table2), + self._union( + {'col_name': 'col_name3', + 'col_type': 'varchar', + 'col_description': 'description of col_name3', + 'col_sort_order': 1, + 'extras': None}, table2) + ] + + extractor = AthenaMetadataExtractor() + extractor.init(self.conf) + + expected = TableMetadata('athena', + self.conf[f'extractor.athena_metadata.{AthenaMetadataExtractor.CATALOG_KEY}'], + 'test_schema1', 'test_table1', '', + [ColumnMetadata('col_id1', 'description of col_id1', 'bigint', 0), + ColumnMetadata('col_id2', 'description of col_id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'partition key', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5)]) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableMetadata('athena', + self.conf[f'extractor.athena_metadata.{AthenaMetadataExtractor.CATALOG_KEY}'], + 'test_schema1', 'test_table2', '', + [ColumnMetadata('col_name', 'description of col_name', 'varchar', 0), + ColumnMetadata('col_name2', 'description of col_name2', 'varchar', 1)]) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableMetadata('athena', + self.conf[f'extractor.athena_metadata.{AthenaMetadataExtractor.CATALOG_KEY}'], + 'test_schema2', 'test_table3', '', + [ColumnMetadata('col_id3', 'description of col_id3', 'varchar', 0), + ColumnMetadata('col_name3', 'description of col_name3', + 'varchar', 1)]) + self.assertEqual(expected.__repr__(), 
extractor.extract().__repr__()) + + self.assertIsNone(extractor.extract()) + self.assertIsNone(extractor.extract()) + + def _union(self, + target: Dict[Any, Any], + extra: Dict[Any, Any]) -> Dict[Any, Any]: + target.update(extra) + return target + + +class TestAthenaMetadataExtractorWithWhereClause(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.where_clause_suffix = """ + where table_schema in ('public') and table_name = 'movies' + """ + config_dict = { + AthenaMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY: self.where_clause_suffix, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = AthenaMetadataExtractor() + extractor.init(self.conf) + self.assertTrue(self.where_clause_suffix in extractor.sql_stmt) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/test_bigquery_metadata_extractor.py b/databuilder/tests/unit/extractor/test_bigquery_metadata_extractor.py new file mode 100644 index 0000000000..9aea92c29a --- /dev/null +++ b/databuilder/tests/unit/extractor/test_bigquery_metadata_extractor.py @@ -0,0 +1,386 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest +from typing import Any + +from mock import Mock, patch +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.bigquery_metadata_extractor import BigQueryMetadataExtractor +from databuilder.models.table_metadata import TableMetadata + +logging.basicConfig(level=logging.INFO) + +NO_DATASETS = {'kind': 'bigquery#datasetList', 'etag': '1B2M2Y8AsgTpgAmY7PhCfg=='} +ONE_DATASET = { + 'kind': 'bigquery#datasetList', 'etag': 'yScH5WIHeNUBF9b/VKybXA==', + 'datasets': [{ + 'kind': 'bigquery#dataset', + 'id': 'your-project-here:empty', + 'datasetReference': { + 'datasetId': 'empty', + 'projectId': 'your-project-here' + }, + 'location': 'US' + }] +} # noqa +NO_TABLES = {'kind': 'bigquery#tableList', 'etag': '1B2M2Y8AsgTpgAmY7PhCfg==', 'totalItems': 0} +ONE_TABLE = { + 'kind': 'bigquery#tableList', 'etag': 'Iaqrz2TCDIANAOD/Xerkjw==', + 'tables': [{ + 'kind': 'bigquery#table', + 'id': 'your-project-here:fdgdfgh.nested_recs', + 'tableReference': { + 'projectId': 'your-project-here', + 'datasetId': 'fdgdfgh', + 'tableId': 'nested_recs' + }, + 'type': 'TABLE', + 'creationTime': '1557578974009' + }], + 'totalItems': 1 +} # noqa +ONE_VIEW = { + 'kind': 'bigquery#tableList', 'etag': 'Iaqrz2TCDIANAOD/Xerkjw==', + 'tables': [{ + 'kind': 'bigquery#table', + 'id': 'your-project-here:fdgdfgh.abab', + 'tableReference': {'projectId': 'your-project-here', 'datasetId': 'fdgdfgh', 'tableId': 'abab'}, + 'type': 'VIEW', + 'view': {'useLegacySql': False}, + 'creationTime': '1557577874991' + }], + 'totalItems': 1 +} # noqa +TIME_PARTITIONED = { + 'kind': 'bigquery#tableList', 'etag': 'Iaqrz2TCDIANAOD/Xerkjw==', + 'tables': [{ + 'kind': 'bigquery#table', + 'id': 'your-project-here:fdgdfgh.other', + 'tableReference': {'projectId': 'your-project-here', 'datasetId': 'fdgdfgh', 'tableId': 'other'}, + 'type': 'TABLE', + 'timePartitioning': {'type': 'DAY', 'requirePartitionFilter': False}, + 'creationTime': '1557577779306' + }], + 'totalItems': 1 +} # noqa +TABLE_DATE_RANGE = { + 'kind': 'bigquery#tableList', 'etag': 
'Iaqrz2TCDIANAOD/Xerkjw==', + 'tables': [{ + 'kind': 'bigquery#table', 'id': 'your-project-here:fdgdfgh.other_20190101', + 'tableReference': {'projectId': 'your-project-here', 'datasetId': 'fdgdfgh', 'tableId': 'date_range_20190101'}, + 'type': 'TABLE', + 'creationTime': '1557577779306' + }, { + 'kind': 'bigquery#table', 'id': 'your-project-here:fdgdfgh.other_20190102', + 'tableReference': {'projectId': 'your-project-here', 'datasetId': 'fdgdfgh', 'tableId': 'date_range_20190102'}, + 'type': 'TABLE', + 'creationTime': '1557577779306' + }], + 'totalItems': 2 +} # noqa +TABLE_DATA = { + 'kind': 'bigquery#table', 'etag': 'Hzc/56Rp9VR4Y6jhZApD/g==', 'id': 'your-project-here:fdgdfgh.test', + 'selfLink': 'https://www.googleapis.com/bigquery/v2/projects/your-project-here/datasets/fdgdfgh/tables/test', + 'tableReference': {'projectId': 'your-project-here', 'datasetId': 'fdgdfgh', 'tableId': 'test'}, + 'schema': { + 'fields': [{'name': 'test', 'type': 'STRING', 'description': 'some_description'}, + {'name': 'test2', 'type': 'INTEGER'}, + {'name': 'test3', 'type': 'FLOAT', 'description': 'another description'}, + {'name': 'test4', 'type': 'BOOLEAN'}, + {'name': 'test5', 'type': 'DATETIME'}] + }, + 'numBytes': '0', + 'numLongTermBytes': '0', + 'numRows': '0', + 'creationTime': '1557577756303', + 'lastModifiedTime': '1557577756370', + 'type': 'TABLE', + 'location': 'EU' +} # noqa +NO_SCHEMA = { + 'kind': 'bigquery#table', 'etag': 'Hzc/56Rp9VR4Y6jhZApD/g==', 'id': 'your-project-here:fdgdfgh.no_schema', + 'selfLink': 'https://www.googleapis.com/bigquery/v2/projects/your-project-here/datasets/fdgdfgh/tables/no_schema', + 'tableReference': {'projectId': 'your-project-here', 'datasetId': 'fdgdfgh', 'tableId': 'no_schema'}, + 'numBytes': '0', + 'numLongTermBytes': '0', + 'numRows': '0', + 'creationTime': '1557577756303', + 'lastModifiedTime': '1557577756370', + 'type': 'TABLE', + 'location': 'EU' +} # noqa +NO_COLS = { + 'kind': 'bigquery#table', 'etag': 'Hzc/56Rp9VR4Y6jhZApD/g==', 'id': 'your-project-here:fdgdfgh.no_columns', + 'selfLink': 'https://www.googleapis.com/bigquery/v2/projects/your-project-here/datasets/fdgdfgh/tables/no_columns', + 'tableReference': {'projectId': 'your-project-here', 'datasetId': 'fdgdfgh', 'tableId': 'no_columns'}, + 'schema': {}, + 'numBytes': '0', + 'numLongTermBytes': '0', + 'numRows': '0', + 'creationTime': '1557577756303', + 'lastModifiedTime': '1557577756370', + 'type': 'TABLE', + 'location': 'EU' +} # noqa +VIEW_DATA = { + 'kind': 'bigquery#table', 'etag': 'E6+jjbQ/HsegSNpTEgELUA==', 'id': 'gerard-cloud-2:fdgdfgh.abab', + 'selfLink': 'https://www.googleapis.com/bigquery/v2/projects/gerard-cloud-2/datasets/fdgdfgh/tables/abab', + 'tableReference': {'projectId': 'gerard-cloud-2', 'datasetId': 'fdgdfgh', 'tableId': 'abab'}, + 'schema': { + 'fields': [ + {'name': 'test', 'type': 'STRING'}, + {'name': 'test2', 'type': 'INTEGER'}, + {'name': 'test3', 'type': 'FLOAT'}, + {'name': 'test4', 'type': 'BOOLEAN'}, + {'name': 'test5', 'type': 'DATETIME'}] + }, + 'numBytes': '0', + 'numLongTermBytes': '0', + 'numRows': '0', + 'creationTime': '1557577874991', + 'lastModifiedTime': '1557577874991', + 'type': 'VIEW', + 'view': {'query': 'SELECT * from `gerard-cloud-2.fdgdfgh.test`', 'useLegacySql': False}, + 'location': 'EU' +} # noqa +NESTED_DATA = { + 'kind': 'bigquery#table', 'etag': 'Hzc/56Rp9VR4Y6jhZApD/g==', 'id': 'your-project-here:fdgdfgh.test', + 'selfLink': 'https://www.googleapis.com/bigquery/v2/projects/your-project-here/datasets/fdgdfgh/tables/test', + 'tableReference': 
{'projectId': 'your-project-here', 'datasetId': 'fdgdfgh', 'tableId': 'test'}, + 'schema': { + 'fields': [{ + 'name': 'nested', 'type': 'RECORD', + 'fields': [{ + 'name': 'nested2', 'type': 'RECORD', + 'fields': [{'name': 'ahah', 'type': 'STRING'}] + }] + }] + }, + 'type': 'TABLE', + 'location': 'EU' +} # noqa + +REPEATED_DATA = { + 'kind': 'bigquery#table', 'etag': 'Hzc/56Rp9VR4Y6jhZApD/g==', 'id': 'your-project-here:fdgdfgh.test', + 'selfLink': 'https://www.googleapis.com/bigquery/v2/projects/your-project-here/datasets/fdgdfgh/tables/test', + 'tableReference': {'projectId': 'your-project-here', 'datasetId': 'fdgdfgh', 'tableId': 'test'}, + 'schema': { + 'fields': [{ + 'name': 'nested', 'type': 'RECORD', + 'fields': [{ + 'name': 'nested2', 'type': 'RECORD', + 'fields': [{'name': 'repeated', 'type': 'STRING', 'mode': 'REPEATED'}] + }] + }] + }, + 'type': 'TABLE', + 'location': 'EU' +} # noqa + +try: + FileNotFoundError +except NameError: + FileNotFoundError = IOError + + +class MockBigQueryClient(): + def __init__(self, + dataset_list_data: Any, + table_list_data: Any, + table_data: Any + ) -> None: + self.ds_execute = Mock() + self.ds_execute.execute.return_value = dataset_list_data + self.ds_list = Mock() + self.ds_list.list.return_value = self.ds_execute + self.list_execute = Mock() + self.list_execute.execute.return_value = table_list_data + self.get_execute = Mock() + self.get_execute.execute.return_value = table_data + self.tables_method = Mock() + self.tables_method.list.return_value = self.list_execute + self.tables_method.get.return_value = self.get_execute + + def datasets(self) -> Any: + return self.ds_list + + def tables(self) -> Any: + return self.tables_method + + +# Patch fallback auth method to avoid actually calling google API +@patch('google.auth.default', lambda scopes: ['dummy', 'dummy']) +class TestBigQueryMetadataExtractor(unittest.TestCase): + def setUp(self) -> None: + config_dict = { + f'extractor.bigquery_table_metadata.{BigQueryMetadataExtractor.PROJECT_ID_KEY}': 'your-project-here' + } + self.conf = ConfigFactory.from_dict(config_dict) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_can_handle_datasets(self, mock_build: Any) -> None: + mock_build.return_value = MockBigQueryClient(NO_DATASETS, None, None) + extractor = BigQueryMetadataExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + result = extractor.extract() + self.assertIsNone(result) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_empty_dataset(self, mock_build: Any) -> None: + mock_build.return_value = MockBigQueryClient(ONE_DATASET, NO_TABLES, None) + extractor = BigQueryMetadataExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + result = extractor.extract() + self.assertIsNone(result) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_accepts_dataset_filter_by_label(self, mock_build: Any) -> None: + config_dict = { + f'extractor.bigquery_table_metadata.{BigQueryMetadataExtractor.PROJECT_ID_KEY}': 'your-project-here', + f'extractor.bigquery_table_metadata.{BigQueryMetadataExtractor.FILTER_KEY}': 'label.key:value' + } + conf = ConfigFactory.from_dict(config_dict) + + mock_build.return_value = MockBigQueryClient(ONE_DATASET, ONE_TABLE, TABLE_DATA) + extractor = BigQueryMetadataExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + result = extractor.extract() + 
self.assertIsInstance(result, TableMetadata) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_table_without_schema(self, mock_build: Any) -> None: + mock_build.return_value = MockBigQueryClient(ONE_DATASET, ONE_TABLE, NO_SCHEMA) + extractor = BigQueryMetadataExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + result = extractor.extract() + + self.assertEqual(result.database, 'bigquery') + self.assertEqual(result.cluster, 'your-project-here') + self.assertEqual(result.schema, 'fdgdfgh') + self.assertEqual(result.name, 'nested_recs') + self.assertEqual(result.description, None) + self.assertEqual(result.columns, []) + self.assertEqual(result.is_view, False) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_table_without_columns(self, mock_build: Any) -> None: + mock_build.return_value = MockBigQueryClient(ONE_DATASET, ONE_TABLE, NO_COLS) + extractor = BigQueryMetadataExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + result = extractor.extract() + + self.assertEqual(result.database, 'bigquery') + self.assertEqual(result.cluster, 'your-project-here') + self.assertEqual(result.schema, 'fdgdfgh') + self.assertEqual(result.name, 'nested_recs') + self.assertEqual(result.description, None) + self.assertEqual(result.columns, []) + self.assertEqual(result.is_view, False) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_view(self, mock_build: Any) -> None: + mock_build.return_value = MockBigQueryClient(ONE_DATASET, ONE_VIEW, VIEW_DATA) + extractor = BigQueryMetadataExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + result = extractor.extract() + self.assertIsInstance(result, TableMetadata) + self.assertEqual(result.is_view, True) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_normal_table(self, mock_build: Any) -> None: + mock_build.return_value = MockBigQueryClient(ONE_DATASET, ONE_TABLE, TABLE_DATA) + extractor = BigQueryMetadataExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + result = extractor.extract() + + self.assertEqual(result.database, 'bigquery') + self.assertEqual(result.cluster, 'your-project-here') + self.assertEqual(result.schema, 'fdgdfgh') + self.assertEqual(result.name, 'nested_recs') + self.assertEqual(result.description, None) + + first_col = result.columns[0] + self.assertEqual(first_col.name, 'test') + self.assertEqual(first_col.type, 'STRING') + self.assertEqual(first_col.description.text, 'some_description') + self.assertEqual(result.is_view, False) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_table_with_nested_records(self, mock_build: Any) -> None: + mock_build.return_value = MockBigQueryClient(ONE_DATASET, ONE_TABLE, NESTED_DATA) + extractor = BigQueryMetadataExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + result = extractor.extract() + + first_col = result.columns[0] + self.assertEqual(first_col.name, 'nested') + self.assertEqual(first_col.type, 'RECORD') + second_col = result.columns[1] + self.assertEqual(second_col.name, 'nested.nested2') + self.assertEqual(second_col.type, 'RECORD') + third_col = result.columns[2] + self.assertEqual(third_col.name, 'nested.nested2.ahah') + self.assertEqual(third_col.type, 'STRING') + + 
@patch('databuilder.extractor.base_bigquery_extractor.build') + def test_keypath_and_pagesize_can_be_set(self, mock_build: Any) -> None: + config_dict = { + f'extractor.bigquery_table_metadata.{BigQueryMetadataExtractor.PROJECT_ID_KEY}': 'your-project-here', + f'extractor.bigquery_table_metadata.{BigQueryMetadataExtractor.PAGE_SIZE_KEY}': 200, + f'extractor.bigquery_table_metadata.{BigQueryMetadataExtractor.KEY_PATH_KEY}': '/tmp/doesnotexist', + } + conf = ConfigFactory.from_dict(config_dict) + + mock_build.return_value = MockBigQueryClient(ONE_DATASET, ONE_TABLE, TABLE_DATA) + extractor = BigQueryMetadataExtractor() + + with self.assertRaises(FileNotFoundError): + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_table_part_of_table_date_range(self, mock_build: Any) -> None: + mock_build.return_value = MockBigQueryClient(ONE_DATASET, TABLE_DATE_RANGE, TABLE_DATA) + extractor = BigQueryMetadataExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + + count = 0 + result = extractor.extract() + table_name = result.name + while result: + count += 1 + result = extractor.extract() + + self.assertEqual(count, 1) + self.assertEqual(table_name, 'date_range_') + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_table_with_repeated_records(self, mock_build: Any) -> None: + mock_build.return_value = MockBigQueryClient(ONE_DATASET, ONE_TABLE, REPEATED_DATA) + extractor = BigQueryMetadataExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + result = extractor.extract() + + first_col = result.columns[0] + self.assertEqual(first_col.name, 'nested') + self.assertEqual(first_col.type, 'RECORD') + second_col = result.columns[1] + self.assertEqual(second_col.name, 'nested.nested2') + self.assertEqual(second_col.type, 'RECORD') + third_col = result.columns[2] + self.assertEqual(third_col.name, 'nested.nested2.repeated') + self.assertEqual(third_col.type, 'STRING:REPEATED') diff --git a/databuilder/tests/unit/extractor/test_bigquery_usage_extractor.py b/databuilder/tests/unit/extractor/test_bigquery_usage_extractor.py new file mode 100644 index 0000000000..6df4386aea --- /dev/null +++ b/databuilder/tests/unit/extractor/test_bigquery_usage_extractor.py @@ -0,0 +1,400 @@ +# Copyright Contributors to the Amundsen project. 
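The nested-record and repeated-record assertions above expect RECORD fields to be flattened into dot-delimited column names (for example `nested.nested2.ahah`) and a `REPEATED` mode to be appended to the leaf type, as in `STRING:REPEATED`. The helper below is a minimal sketch of that flattening behaviour, assuming only the schema shape shown in the fixtures; the function name `flatten_fields` is illustrative and is not the extractor's actual implementation.

```python
from typing import Any, Dict, Iterator, List, Tuple


def flatten_fields(fields: List[Dict[str, Any]], parent: str = '') -> Iterator[Tuple[str, str]]:
    """Yield (column_name, column_type) pairs for a BigQuery schema fragment.

    Nested RECORD fields are flattened into dot-delimited names, and a
    REPEATED mode is appended to the type, mirroring the assertions above.
    """
    for field in fields:
        name = f"{parent}.{field['name']}" if parent else field['name']
        col_type = field['type']
        if field.get('mode') == 'REPEATED':
            col_type = f"{col_type}:REPEATED"
        yield name, col_type
        if field['type'] == 'RECORD':
            yield from flatten_fields(field.get('fields', []), parent=name)


# Applied to the NESTED_DATA schema above, this yields:
# [('nested', 'RECORD'), ('nested.nested2', 'RECORD'), ('nested.nested2.ahah', 'STRING')]
```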
+# SPDX-License-Identifier: Apache-2.0 + +import base64 +import tempfile +import unittest +from typing import Any + +from mock import Mock, patch +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.bigquery_usage_extractor import BigQueryTableUsageExtractor, TableColumnUsageTuple + +CORRECT_DATA = { + "entries": [{ + "protoPayload": { + "@type": "type.googleapis.com/google.cloud.audit.AuditLog", + "status": {}, + "authenticationInfo": { + "principalEmail": "your-user-here@test.com" + }, + "serviceName": "bigquery.googleapis.com", + "methodName": "jobservice.jobcompleted", + "resourceName": "projects/your-project-here/jobs/bquxjob_758c08d1_16a96889839", + "serviceData": { + "@type": "type.googleapis.com/google.cloud.bigquery.logging.v1.AuditData", + "jobCompletedEvent": { + "eventName": "query_job_completed", + "job": { + "jobName": { + "projectId": "your-project-here", + "jobId": "bquxjob_758c08d1_16a96889839", + "location": "US" + }, + "jobConfiguration": { + "query": { + "query": "select descript from " + "`bigquery-public-data.austin_incidents.incidents_2008`\n", + "destinationTable": { + "projectId": "your-project-here", + "datasetId": "_07147a061ddfd6dcaf246cfc5e858a0ccefa7080", + "tableId": "anon1dd83635c62357091e55a5f76fb62d7deebcfa4c" + }, + "createDisposition": "CREATE_IF_NEEDED", + "writeDisposition": "WRITE_TRUNCATE", + "defaultDataset": {}, + "queryPriority": "QUERY_INTERACTIVE", + "statementType": "SELECT" + } + }, + "jobStatus": { + "state": "DONE", + "error": {} + }, + "jobStatistics": { + "createTime": "2019-05-08T08:22:56.349Z", + "startTime": "2019-05-08T08:22:56.660Z", + "endTime": "2019-05-08T08:23:00.049Z", + "totalProcessedBytes": "3637807", + "totalBilledBytes": "10485760", + "billingTier": 1, + "totalSlotMs": "452", + "referencedTables": [ + { + "projectId": "bigquery-public-data", + "datasetId": "austin_incidents", + "tableId": "incidents_2008" + } + ], + "totalTablesProcessed": 1, + "queryOutputRowCount": "179524" + } + } + } + } + }, + "insertId": "-jyqvjse6lwjz", + "resource": { + "type": "bigquery_resource", + "labels": { + "project_id": "your-project-here" + } + }, + "timestamp": "2019-05-08T08:23:00.061Z", + "severity": "INFO", + "logName": "projects/your-project-here/logs/cloudaudit.googleapis.com%2Fdata_access", + "receiveTimestamp": "2019-05-08T08:23:00.310709609Z" + }] +} # noqa + +FAILURE = { + "entries": [{ + "protoPayload": { + "authenticationInfo": { + "principalEmail": "your-user-here@test.com" + }, + "methodName": "jobservice.jobcompleted", + "serviceData": { + "jobCompletedEvent": { + "job": { + "jobStatus": { + "state": "DONE", + "error": { + "code": 11, + "message": "Some descriptive error message" + } + }, + "jobStatistics": { + "createTime": "2019-05-08T08:22:56.349Z", + "startTime": "2019-05-08T08:22:56.660Z", + "endTime": "2019-05-08T08:23:00.049Z", + "totalProcessedBytes": "3637807", + "totalBilledBytes": "10485760", + "referencedTables": [ + { + "projectId": "bigquery-public-data", + "datasetId": "austin_incidents", + "tableId": "incidents_2008" + } + ] + } + } + } + }, + }, + }] +} # noqa + +# An empty dict will be ignored, but putting in nextPageToken causes the test +# to loop infinitely, so we need a bogus key/value to ensure that we will try +# to read entries +NO_ENTRIES = {'key': 'value'} # noqa + +KEYFILE_DATA = """ +ewogICJ0eXBlIjogInNlcnZpY2VfYWNjb3VudCIsCiAgInByb2plY3RfaWQiOiAieW91ci1wcm9q +ZWN0LWhlcmUiLAogICJwcml2YXRlX2tleV9pZCI6ICJiMDQ0N2U1ODEyYTg5ZTAyOTgxYjRkMWE1 
+YjE1N2NlNzZkOWJlZTc3IiwKICAicHJpdmF0ZV9rZXkiOiAiLS0tLS1CRUdJTiBQUklWQVRFIEtF +WS0tLS0tXG5NSUlFdkFJQkFEQU5CZ2txaGtpRzl3MEJBUUVGQUFTQ0JLWXdnZ1NpQWdFQUFvSUJB +UUM1UzBYRWtHY2NuOEsxXG5ZbHhRbXlhRWFZK2grYnRacHRVWjJiK2J1cTluNExKU3I3eTdPQWll +ZjBWazIyQnc1TFRsUXRQSUtNVkh6MzJMXG5Ld0lJYmY5Wkwzamd5UC9hNHIveHVhMVdzNFF2YVkz +TGoxRG1ITm40L3hQNXdDY0VscHIxV2RXL05VZ1RQV1A2XG5LZnVDdHhyQTJxbHJNazhyYklXVTRm +WTAzQmFqdzNHT0p4VDBvbXlCVmdGSzJTdGRFUVVYMm9YQVdSNXJyR21qXG5qWTNzb3lNU0NwSWtT +b0h4b1BrVEM0VzZ2a3dJRlk4SUkwbmhsWUZHc3FiZjdkbTBLVEZmVVh5SUFTOHd6RCtlXG54UFVQ +V3k0UXA5cTVyNTVPRmlxdWt3TGNZei9BQXFpYTU3KzhURmhiWXcwUXNsZ2xSaWFLWkVhQyt4M0pD +OEhuXG5KajY2WE5mTEFnTUJBQUVDZ2dFQVMyNFlGYi9QS2ZqamM2RjZBUnBYNExsMFRqVHlqcmw2 +c001UzBSdDdRbWRYXG5VSS9YM2NNZXh4NzZhZWRnYURUQ2F6MzhKdFJxRXlTbGI5enZNKzFMY013 +QmdraHcxM05OUGlNZkxGZGg3VWNrXG5BUVR6b3VtRjFuWklkSGhEcWZ1QlUzWGhyTGdOQWtBUWpn +cy9KdVJSVU1iekJ2OXcrVFZ4WDcxbzAvWHdoWE5kXG5kSWlWdE1TbnFWQ0J2cEp3ZXBoR3FxNGQ3 +VEIzb2F3UUg1QkFGeHk5NGpoT0dwaVFWYW8yQmtPdEVyVVBQYjkrXG5vRzByZTM3WHVtQzZRWENv +VSs4Zm4vcE1YVWVOUitXSm5tY1lndVZqWDl6QzJ3MU13cmVmOFVKa1Q4SHJxZ09KXG5sWnNFcVJr +aHBYUFVzdmt2dWxQTWQ3TitJdlFvYTh0N3ZaZFkrR1lMdVFLQmdRRHd2enY0alhVUStIU1RaVm1p +XG5hQmNMVGRMRE5WNlpuT25aTEhxbDZaQmloTUhZNi9qS2xDN1hqWGJaQ2NqS05MMkE1am9mQ0d5 +bHFhNFRrZnArXG5rYmJKQ29KS2tFY1pSWGQ3NEdXb0J1V2d3enY2WWFkcDNxS2x0RndhM1FjMkJ3 +SlNlazkrTzd6OGs2d0dvclZJXG5OK3ZNMVd3OWJPa1VaaXh4T2g2V2ZKSTl6UUtCZ1FERkNLQXZ2 +b3FUQnErMnovazhLYy9lTHVRdThPWWNXVm9GXG55eXprOTN2QnBXcEVPT1hybnNsUFFtQldUdTN5 +UWpRN08zd2t1c0g3VUtJQTg0MDVHbDlwbmJvTmlaSVdBRlpvXG4vVWlVVm5aa3pvZER5Tk9PUjBm +UW5zM1BaeE5peklSSjh2Mm93a2d3MExFYWEwaWUyNU92bFJmQ2pmYlVZL0EzXG5wbU9SVkdFVDl3 +S0JnR0Zab3lHRjZoRzd0a0FvR28vT3NZclRwR2RsZkdSM2pDUlNsU0hrQ1l1ZERWbnZTY0o1XG5H +MXYwaTF1R1ZsaFY3VTlqU1p0azU3SXhvLytyNXZRcGJoVnJsM1laVTNiSG5XSk5RaTRvNDlBWFFu +aWo1bk9zXG5JRzhMT0xkd0swdFFtRUxMekx0SjRzanIyZ013NWtkV3ZaWXRzMEEvZXh6Um1DVU5F +SE5mMmk3OUFvR0FESVpkXG4yR3NlVi9aRzJUSWpQOFhRcHVrSUxFdTM5UGxoRlpreXcyTlFCS0ZG +UGd6MzRLQjVYNFp5cFVuaktsRTNETVRkXG5RV0IxMEVueDRtbVpBcFpBbG5BbVVaSDdMVmJjSjFS +aWRydUFUeXdwd1E5VkUyaElrbVJsNU5kQ2pqYzkrWTF1XG52bm1MS1Q4NjR0a0xCcjRpaHpqTkI5 +c0tZN251blRzQWZVNkYxVVVDZ1lBMmdlMFdiVEVwRlBuN05YYjZ4citiXG5QK1RFVEVWZzhRS0Z1 +OUtHVk03NXI5dmhYblNicmphbGVCSzJFQzBLK2F2d2hHTTd3eXRqM0FrTjRac2NKNWltXG5VZTBw +Z3pVSE1RSVI1OWlGVmt5WVVjZnZMSERZU0xmeW9QVU5RWWduVXBKYlZOczZtWFRqQ3o2UERrb0tX +ZzcyXG4rS3p4RWhubWJzY0NiSFRpQ08wNEtBPT1cbi0tLS0tRU5EIFBSSVZBVEUgS0VZLS0tLS1c +biIsCiAgImNsaWVudF9lbWFpbCI6ICJ0ZXN0LTE2MkB5b3VyLXByb2plY3QtaGVyZS5pYW0uZ3Nl +cnZpY2VhY2NvdW50LmNvbSIsCiAgImNsaWVudF9pZCI6ICIxMDg2NTMzMjY0MzE1NDU2ODg3MTAi +LAogICJhdXRoX3VyaSI6ICJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20vby9vYXV0aDIvYXV0 +aCIsCiAgInRva2VuX3VyaSI6ICJodHRwczovL29hdXRoMi5nb29nbGVhcGlzLmNvbS90b2tlbiIs +CiAgImF1dGhfcHJvdmlkZXJfeDUwOV9jZXJ0X3VybCI6ICJodHRwczovL3d3dy5nb29nbGVhcGlz +LmNvbS9vYXV0aDIvdjEvY2VydHMiLAogICJjbGllbnRfeDUwOV9jZXJ0X3VybCI6ICJodHRwczov +L3d3dy5nb29nbGVhcGlzLmNvbS9yb2JvdC92MS9tZXRhZGF0YS94NTA5L3Rlc3QtMTYyJTQweW91 +ci1wcm9qZWN0LWhlcmUuaWFtLmdzZXJ2aWNlYWNjb3VudC5jb20iCn0KCgo= +""" + + +class MockLoggingClient(): + def __init__(self, data: Any) -> None: + self.data = data + self.a = Mock() + self.a.execute.return_value = self.data + self.b = Mock() + self.b.list.return_value = self.a + + def entries(self) -> Any: + return self.b + + +# Patch fallback auth method to avoid actually calling google API +@patch('google.auth.default', lambda scopes: ['dummy', 'dummy']) +class TestBigqueryUsageExtractor(unittest.TestCase): + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def 
test_basic_extraction(self, mock_build: Any) -> None: + """ + Test Extraction using mock class + """ + config_dict = { + f'extractor.bigquery_table_usage.{BigQueryTableUsageExtractor.PROJECT_ID_KEY}': 'bigquery-public-data', + } + conf = ConfigFactory.from_dict(config_dict) + + mock_build.return_value = MockLoggingClient(CORRECT_DATA) + extractor = BigQueryTableUsageExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + result = extractor.extract() + assert result is not None + self.assertIsInstance(result, tuple) + + (key, value) = result + self.assertIsInstance(key, TableColumnUsageTuple) + self.assertIsInstance(value, int) + + self.assertEqual(key.database, 'bigquery') + self.assertEqual(key.cluster, 'bigquery-public-data') + self.assertEqual(key.schema, 'austin_incidents') + self.assertEqual(key.table, 'incidents_2008') + self.assertEqual(key.email, 'your-user-here@test.com') + self.assertEqual(value, 1) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_no_entries(self, mock_build: Any) -> None: + config_dict = { + f'extractor.bigquery_table_usage.{BigQueryTableUsageExtractor.PROJECT_ID_KEY}': 'bigquery-public-data', + } + conf = ConfigFactory.from_dict(config_dict) + + mock_build.return_value = MockLoggingClient(NO_ENTRIES) + extractor = BigQueryTableUsageExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + result = extractor.extract() + self.assertIsNone(result) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_key_path(self, mock_build: Any) -> None: + """ + Test key_path can be used + """ + + with tempfile.NamedTemporaryFile() as keyfile: + # There are many github scanners looking for API / cloud keys, so in order not to get a + # false positive triggering everywhere, I base64 encoded the key. + # This is written to a tempfile as part of this test and then used. 
+ keyfile.write(base64.b64decode(KEYFILE_DATA)) + keyfile.flush() + config_dict = { + f'extractor.bigquery_table_usage.{BigQueryTableUsageExtractor.PROJECT_ID_KEY}': 'bigquery-public-data', + f'extractor.bigquery_table_usage.{BigQueryTableUsageExtractor.KEY_PATH_KEY}': keyfile.name, + } + conf = ConfigFactory.from_dict(config_dict) + + mock_build.return_value = MockLoggingClient(CORRECT_DATA) + extractor = BigQueryTableUsageExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + + args, kwargs = mock_build.call_args + creds = kwargs['http'].credentials + self.assertEqual(creds.project_id, 'your-project-here') + self.assertEqual(creds.service_account_email, 'test-162@your-project-here.iam.gserviceaccount.com') + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_timestamp_pagesize_settings(self, mock_build: Any) -> None: + """ + Test timestamp and pagesize can be set + """ + TIMESTAMP = '2019-01-01T00:00:00.00Z' + PAGESIZE = 215 + + config_dict = { + f'extractor.bigquery_table_usage.{BigQueryTableUsageExtractor.PROJECT_ID_KEY}': 'bigquery-public-data', + f'extractor.bigquery_table_usage.{BigQueryTableUsageExtractor.TIMESTAMP_KEY}': TIMESTAMP, + f'extractor.bigquery_table_usage.{BigQueryTableUsageExtractor.PAGE_SIZE_KEY}': PAGESIZE, + } + conf = ConfigFactory.from_dict(config_dict) + + client = MockLoggingClient(CORRECT_DATA) + mock_build.return_value = client + extractor = BigQueryTableUsageExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + + args, kwargs = client.b.list.call_args + body = kwargs['body'] + + self.assertEqual(body['pageSize'], PAGESIZE) + self.assertEqual(TIMESTAMP in body['filter'], True) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_failed_jobs_should_not_be_counted(self, mock_build: Any) -> None: + config_dict = { + f'extractor.bigquery_table_usage.{BigQueryTableUsageExtractor.PROJECT_ID_KEY}': 'bigquery-public-data', + } + conf = ConfigFactory.from_dict(config_dict) + + client = MockLoggingClient(FAILURE) + mock_build.return_value = client + extractor = BigQueryTableUsageExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + + result = extractor.extract() + self.assertIsNone(result) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_email_filter_not_counted(self, mock_build: Any) -> None: + config_dict = { + f'extractor.bigquery_table_usage.{BigQueryTableUsageExtractor.PROJECT_ID_KEY}': 'bigquery-public-data', + f'extractor.bigquery_table_usage.{BigQueryTableUsageExtractor.EMAIL_PATTERN}': 'emailFilter', + } + conf = ConfigFactory.from_dict(config_dict) + + mock_build.return_value = MockLoggingClient(CORRECT_DATA) + extractor = BigQueryTableUsageExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + result = extractor.extract() + self.assertIsNone(result) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_email_filter_counted(self, mock_build: Any) -> None: + config_dict = { + f'extractor.bigquery_table_usage.{BigQueryTableUsageExtractor.PROJECT_ID_KEY}': 'bigquery-public-data', + f'extractor.bigquery_table_usage.{BigQueryTableUsageExtractor.EMAIL_PATTERN}': '.*@test.com.*', + } + conf = ConfigFactory.from_dict(config_dict) + + mock_build.return_value = MockLoggingClient(CORRECT_DATA) + extractor = BigQueryTableUsageExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, + 
scope=extractor.get_scope())) + result = extractor.extract() + assert result is not None + self.assertIsInstance(result, tuple) + + (key, value) = result + self.assertIsInstance(key, TableColumnUsageTuple) + self.assertIsInstance(value, int) + + self.assertEqual(key.database, 'bigquery') + self.assertEqual(key.cluster, 'bigquery-public-data') + self.assertEqual(key.schema, 'austin_incidents') + self.assertEqual(key.table, 'incidents_2008') + self.assertEqual(key.email, 'your-user-here@test.com') + self.assertEqual(value, 1) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_not_counting_referenced_table_belonging_to_different_project(self, mock_build: Any) -> None: + """ + Test result when referenced table belongs to a project different from the PROJECT_ID_KEY of the extractor + """ + config_dict = { + f'extractor.bigquery_table_usage.{BigQueryTableUsageExtractor.PROJECT_ID_KEY}': 'your-project-here', + } + conf = ConfigFactory.from_dict(config_dict) + + mock_build.return_value = MockLoggingClient(CORRECT_DATA) + extractor = BigQueryTableUsageExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + + result = extractor.extract() + assert result is None + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_counting_referenced_table_belonging_to_different_project(self, mock_build: Any) -> None: + """ + Test result when referenced table belongs to a project different from the PROJECT_ID_KEY of the extractor + and COUNT_READS_ONLY_FROM_PROJECT is set to False + """ + config_dict = { + f'extractor.bigquery_table_usage.{BigQueryTableUsageExtractor.PROJECT_ID_KEY}': 'your-project-here', + f'extractor.bigquery_table_usage.{BigQueryTableUsageExtractor.COUNT_READS_ONLY_FROM_PROJECT_ID_KEY}': False, + } + conf = ConfigFactory.from_dict(config_dict) + + mock_build.return_value = MockLoggingClient(CORRECT_DATA) + extractor = BigQueryTableUsageExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + + result = extractor.extract() + assert result is not None + self.assertIsInstance(result, tuple) + + (key, value) = result + self.assertIsInstance(key, TableColumnUsageTuple) + self.assertIsInstance(value, int) + + self.assertEqual(key.database, 'bigquery') + self.assertEqual(key.cluster, 'bigquery-public-data') + self.assertEqual(key.schema, 'austin_incidents') + self.assertEqual(key.table, 'incidents_2008') + self.assertEqual(key.email, 'your-user-here@test.com') + self.assertEqual(value, 1) diff --git a/databuilder/tests/unit/extractor/test_bigquery_watermark_extractor.py b/databuilder/tests/unit/extractor/test_bigquery_watermark_extractor.py new file mode 100644 index 0000000000..2c3a3a5116 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_bigquery_watermark_extractor.py @@ -0,0 +1,308 @@ +# Copyright Contributors to the Amundsen project. 
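The hand-rolled MockBigQueryClient and MockLoggingClient fakes in these tests wire up one Mock per hop of the googleapiclient discovery chain (for example `client.entries().list(...).execute()`), which also makes it easy to inspect `call_args` on a specific hop. When a test only cares about the final payload, the same chain can be expressed more compactly by chaining `return_value` on a single Mock, as in this optional sketch (the canned payload reuses the shape of the `NO_DATASETS` fixture):

```python
from unittest.mock import Mock

# A single Mock can stand in for the whole discovery resource: chaining
# return_value mirrors the client.datasets().list(...).execute() call path.
client = Mock()
client.datasets.return_value.list.return_value.execute.return_value = {
    'kind': 'bigquery#datasetList', 'etag': '1B2M2Y8AsgTpgAmY7PhCfg=='
}

payload = client.datasets().list(projectId='your-project-here').execute()
assert payload['kind'] == 'bigquery#datasetList'
```

The explicit helper classes used in these files remain the better fit when a test needs to assert on the arguments passed to an intermediate call, as `test_timestamp_pagesize_settings` does with `client.b.list.call_args`.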
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest +from datetime import datetime +from typing import Any + +from mock import Mock, patch +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.bigquery_watermark_extractor import BigQueryWatermarkExtractor + +logging.basicConfig(level=logging.INFO) + +NO_DATASETS = {'kind': 'bigquery#datasetList', 'etag': '1B2M2Y8AsgTpgAmY7PhCfg=='} +ONE_DATASET = { + 'kind': 'bigquery#datasetList', 'etag': 'yScH5WIHeNUBF9b/VKybXA==', + 'datasets': [{ + 'kind': 'bigquery#dataset', + 'id': 'your-project-here:empty', + 'datasetReference': {'datasetId': 'empty', 'projectId': 'your-project-here'}, + 'location': 'US' + }] +} # noqa +NO_TABLES = {'kind': 'bigquery#tableList', 'etag': '1B2M2Y8AsgTpgAmY7PhCfg==', 'totalItems': 0} +ONE_TABLE = { + 'kind': 'bigquery#tableList', 'etag': 'Iaqrz2TCDIANAOD/Xerkjw==', + 'tables': [{ + 'kind': 'bigquery#table', + 'id': 'your-project-here:fdgdfgh.nested_recs', + 'tableReference': {'projectId': 'your-project-here', 'datasetId': 'fdgdfgh', 'tableId': 'nested_recs'}, + 'type': 'TABLE', + 'creationTime': '1557578974009' + }], + 'totalItems': 1 +} # noqa +TIME_PARTITIONED = { + 'kind': 'bigquery#tableList', 'etag': 'Iaqrz2TCDIANAOD/Xerkjw==', + 'tables': [{ + 'kind': 'bigquery#table', + 'id': 'your-project-here:fdgdfgh.other', + 'tableReference': {'projectId': 'your-project-here', 'datasetId': 'fdgdfgh', 'tableId': 'other'}, + 'type': 'TABLE', + 'timePartitioning': {'type': 'DAY', 'requirePartitionFilter': False}, + 'creationTime': '1557577779306' + }], + 'totalItems': 1 +} # noqa +TIME_PARTITIONED_WITH_FIELD = { + 'kind': 'bigquery#tableList', 'etag': 'Iaqrz2TCDIANAOD/Xerkjw==', + 'tables': [{ + 'kind': 'bigquery#table', + 'id': 'your-project-here:fdgdfgh.other', + 'tableReference': {'projectId': 'your-project-here', 'datasetId': 'fdgdfgh', 'tableId': 'other'}, + 'type': 'TABLE', + 'timePartitioning': {'type': 'DAY', 'field': 'processed_date', 'requirePartitionFilter': False}, + 'creationTime': '1557577779306' + }], + 'totalItems': 1 +} # noqa +TABLE_DATE_RANGE = { + 'kind': 'bigquery#tableList', 'etag': 'Iaqrz2TCDIANAOD/Xerkjw==', + 'tables': [{ + 'kind': 'bigquery#table', + 'id': 'your-project-here:fdgdfgh.other_20190101', + 'tableReference': {'projectId': 'your-project-here', 'datasetId': 'fdgdfgh', 'tableId': 'date_range_20190101'}, + 'type': 'TABLE', + 'creationTime': '1557577779306' + }, { + 'kind': 'bigquery#table', + 'id': 'your-project-here:fdgdfgh.other_20190102', + 'tableReference': {'projectId': 'your-project-here', 'datasetId': 'fdgdfgh', 'tableId': 'date_range_20190102'}, + 'type': 'TABLE', + 'creationTime': '1557577779306' + }], + 'totalItems': 2 +} # noqa +PARTITION_DATA = { + 'kind': 'bigquery#queryResponse', + 'schema': { + 'fields': [{ + 'name': 'partition_id', + 'type': 'STRING', + 'mode': 'NULLABLE' + }, { + 'name': 'creation_time', + 'type': 'TIMESTAMP', + 'mode': 'NULLABLE' + }] + }, + 'jobReference': {'projectId': 'your-project-here', 'jobId': 'job_bfTRGj3Lv0tRjcrotXbZSgMCpNhY', 'location': 'EU'}, + 'totalRows': '3', + 'rows': [{'f': [{'v': '20180802'}, {'v': '1.547512241348E9'}]}, + {'f': [{'v': '20180803'}, {'v': '1.547512241348E9'}]}, + {'f': [{'v': '20180804'}, {'v': '1.547512241348E9'}]}], + 'totalBytesProcessed': '0', + 'jobComplete': True, + 'cacheHit': False +} # noqa + +try: + FileNotFoundError +except NameError: + FileNotFoundError = IOError + + +class MockBigQueryClient(): + def __init__(self, + dataset_list_data: Any, + 
table_list_data: Any, + partition_data: Any + ) -> None: + self.list_execute = Mock() + self.list_execute.execute.return_value = table_list_data + self.tables_method = Mock() + self.tables_method.list.return_value = self.list_execute + self.ds_execute = Mock() + self.ds_execute.execute.return_value = dataset_list_data + self.ds_list = Mock() + self.ds_list.list.return_value = self.ds_execute + self.query_execute = Mock() + self.query_execute.execute.return_value = partition_data + self.jobs_query = Mock() + self.jobs_query.query.return_value = self.query_execute + + def datasets(self) -> Any: + return self.ds_list + + def tables(self) -> Any: + return self.tables_method + + def jobs(self) -> Any: + return self.jobs_query + + +# Patch fallback auth method to avoid actually calling google API +@patch('google.auth.default', lambda scopes: ['dummy', 'dummy']) +class TestBigQueryWatermarkExtractor(unittest.TestCase): + def setUp(self) -> None: + config_dict = { + f'extractor.bigquery_watermarks.{BigQueryWatermarkExtractor.PROJECT_ID_KEY}': + 'your-project-here'} + self.conf = ConfigFactory.from_dict(config_dict) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_can_handle_no_datasets(self, mock_build: Any) -> None: + mock_build.return_value = MockBigQueryClient(NO_DATASETS, None, None) + extractor = BigQueryWatermarkExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + result = extractor.extract() + self.assertIsNone(result) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_empty_dataset(self, mock_build: Any) -> None: + mock_build.return_value = MockBigQueryClient(ONE_DATASET, NO_TABLES, None) + extractor = BigQueryWatermarkExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + result = extractor.extract() + self.assertIsNone(result) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_table_without_partitions(self, mock_build: Any) -> None: + mock_build.return_value = MockBigQueryClient(ONE_DATASET, ONE_TABLE, None) + extractor = BigQueryWatermarkExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + result = extractor.extract() + self.assertIsNone(result) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_table_with_default_partitions(self, mock_build: Any) -> None: + mock_build.return_value = MockBigQueryClient(ONE_DATASET, TIME_PARTITIONED, PARTITION_DATA) + extractor = BigQueryWatermarkExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + result = extractor.extract() + self.assertEqual(result.part_type, 'low_watermark') + self.assertEqual(result.database, 'bigquery') + self.assertEqual(result.schema, 'fdgdfgh') + self.assertEqual(result.table, 'other') + self.assertEqual(result.cluster, 'your-project-here') + self.assertEqual(result.create_time, datetime.fromtimestamp(1547512241).strftime('%Y-%m-%d %H:%M:%S')) + self.assertEqual(result.parts, [('_PARTITIONTIME', '20180802')]) + + result = extractor.extract() + self.assertEqual(result.part_type, 'high_watermark') + self.assertEqual(result.database, 'bigquery') + self.assertEqual(result.schema, 'fdgdfgh') + self.assertEqual(result.table, 'other') + self.assertEqual(result.cluster, 'your-project-here') + self.assertEqual(result.create_time, datetime.fromtimestamp(1547512241).strftime('%Y-%m-%d %H:%M:%S')) + self.assertEqual(result.parts, 
[('_PARTITIONTIME', '20180804')]) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_table_with_field_partitions(self, mock_build: Any) -> None: + mock_build.return_value = MockBigQueryClient(ONE_DATASET, TIME_PARTITIONED_WITH_FIELD, PARTITION_DATA) + extractor = BigQueryWatermarkExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + result = extractor.extract() + assert result is not None + self.assertEqual(result.part_type, 'low_watermark') + self.assertEqual(result.database, 'bigquery') + self.assertEqual(result.schema, 'fdgdfgh') + self.assertEqual(result.table, 'other') + self.assertEqual(result.cluster, 'your-project-here') + self.assertEqual(result.create_time, datetime.fromtimestamp(1547512241).strftime('%Y-%m-%d %H:%M:%S')) + self.assertEqual(result.parts, [('processed_date', '20180802')]) + + result = extractor.extract() + assert result is not None + self.assertEqual(result.part_type, 'high_watermark') + self.assertEqual(result.database, 'bigquery') + self.assertEqual(result.schema, 'fdgdfgh') + self.assertEqual(result.table, 'other') + self.assertEqual(result.cluster, 'your-project-here') + self.assertEqual(result.create_time, datetime.fromtimestamp(1547512241).strftime('%Y-%m-%d %H:%M:%S')) + self.assertEqual(result.parts, [('processed_date', '20180804')]) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_keypath_can_be_set(self, mock_build: Any) -> None: + config_dict = { + f'extractor.bigquery_watermarks.{BigQueryWatermarkExtractor.PROJECT_ID_KEY}': 'your-project-here', + f'extractor.bigquery_watermarks.{BigQueryWatermarkExtractor.KEY_PATH_KEY}': '/tmp/doesnotexist', + } + conf = ConfigFactory.from_dict(config_dict) + + mock_build.return_value = MockBigQueryClient(ONE_DATASET, ONE_TABLE, None) + extractor = BigQueryWatermarkExtractor() + + with self.assertRaises(FileNotFoundError): + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_table_part_of_table_date_range(self, mock_build: Any) -> None: + mock_build.return_value = MockBigQueryClient(ONE_DATASET, TABLE_DATE_RANGE, None) + extractor = BigQueryWatermarkExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + + result = extractor.extract() + assert result is not None + self.assertEqual(result.part_type, 'low_watermark') + self.assertEqual(result.database, 'bigquery') + self.assertEqual(result.schema, 'fdgdfgh') + self.assertEqual(result.table, 'date_range_') + self.assertEqual(result.cluster, 'your-project-here') + self.assertEqual(result.create_time, datetime.fromtimestamp(1557577779).strftime('%Y-%m-%d %H:%M:%S')) + self.assertEqual(result.parts, [('__table__', '20190101')]) + + result = extractor.extract() + assert result is not None + self.assertEqual(result.part_type, 'high_watermark') + self.assertEqual(result.database, 'bigquery') + self.assertEqual(result.schema, 'fdgdfgh') + self.assertEqual(result.table, 'date_range_') + self.assertEqual(result.cluster, 'your-project-here') + self.assertEqual(result.create_time, datetime.fromtimestamp(1557577779).strftime('%Y-%m-%d %H:%M:%S')) + self.assertEqual(result.parts, [('__table__', '20190102')]) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_table_creation_time_after_cutoff_time(self, mock_build: Any) -> None: + config_dict = { + 
f'extractor.bigquery_watermarks.{BigQueryWatermarkExtractor.PROJECT_ID_KEY}': 'your-project-here', + f'extractor.bigquery_watermarks.{BigQueryWatermarkExtractor.CUTOFF_TIME_KEY}': '2019-05-10T20:10:22Z' + } + conf = ConfigFactory.from_dict(config_dict) + mock_build.return_value = MockBigQueryClient(ONE_DATASET, TIME_PARTITIONED, PARTITION_DATA) + extractor = BigQueryWatermarkExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + result = extractor.extract() + self.assertIsNone(result) + + @patch('databuilder.extractor.base_bigquery_extractor.build') + def test_table_creation_time_before_cutoff_time(self, mock_build: Any) -> None: + config_dict = { + f'extractor.bigquery_watermarks.{BigQueryWatermarkExtractor.PROJECT_ID_KEY}': 'your-project-here', + f'extractor.bigquery_watermarks.{BigQueryWatermarkExtractor.CUTOFF_TIME_KEY}': '2021-04-27T20:10:22Z' + } + conf = ConfigFactory.from_dict(config_dict) + mock_build.return_value = MockBigQueryClient(ONE_DATASET, TIME_PARTITIONED, PARTITION_DATA) + extractor = BigQueryWatermarkExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + result = extractor.extract() + assert result is not None + self.assertEqual(result.part_type, 'low_watermark') + self.assertEqual(result.database, 'bigquery') + self.assertEqual(result.schema, 'fdgdfgh') + self.assertEqual(result.table, 'other') + self.assertEqual(result.cluster, 'your-project-here') + self.assertEqual(result.create_time, datetime.fromtimestamp(1547512241).strftime('%Y-%m-%d %H:%M:%S')) + self.assertEqual(result.parts, [('_PARTITIONTIME', '20180802')]) + + result = extractor.extract() + self.assertEqual(result.part_type, 'high_watermark') + self.assertEqual(result.database, 'bigquery') + self.assertEqual(result.schema, 'fdgdfgh') + self.assertEqual(result.table, 'other') + self.assertEqual(result.cluster, 'your-project-here') + self.assertEqual(result.create_time, datetime.fromtimestamp(1547512241).strftime('%Y-%m-%d %H:%M:%S')) + self.assertEqual(result.parts, [('_PARTITIONTIME', '20180804')]) diff --git a/databuilder/tests/unit/extractor/test_cassandra_extractor.py b/databuilder/tests/unit/extractor/test_cassandra_extractor.py new file mode 100644 index 0000000000..e938c375d7 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_cassandra_extractor.py @@ -0,0 +1,87 @@ +# Copyright Contributors to the Amundsen project. 
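The watermark assertions in the tests above reduce to taking the minimum and maximum `partition_id` from the rows returned by the partition query and emitting them as the low and high watermarks. The sketch below reproduces that reduction against the row shape of the `PARTITION_DATA` fixture; the helper name `partition_bounds` is an assumption for illustration, not a databuilder API.

```python
from typing import Any, List, Tuple

# Row shape copied from the PARTITION_DATA fixture above.
PARTITION_ROWS = [{'f': [{'v': '20180802'}, {'v': '1.547512241348E9'}]},
                  {'f': [{'v': '20180803'}, {'v': '1.547512241348E9'}]},
                  {'f': [{'v': '20180804'}, {'v': '1.547512241348E9'}]}]


def partition_bounds(rows: List[Any]) -> Tuple[str, str]:
    """Return (low, high) partition ids; the first field of each row is the id."""
    partition_ids = [row['f'][0]['v'] for row in rows]
    return min(partition_ids), max(partition_ids)


assert partition_bounds(PARTITION_ROWS) == ('20180802', '20180804')
```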
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import unittest
+from collections import OrderedDict
+from typing import Any
+
+from cassandra.metadata import ColumnMetadata as CassandraColumnMetadata
+from mock import patch
+from pyhocon import ConfigFactory
+
+from databuilder.extractor.cassandra_extractor import CassandraExtractor
+from databuilder.models.table_metadata import ColumnMetadata, TableMetadata
+
+
+# Patch the whole class so tests never open a real Cassandra cluster connection
+@patch('cassandra.cluster.Cluster.connect', lambda x: None)
+class TestCassandraExtractor(unittest.TestCase):
+    def setUp(self) -> None:
+        logging.basicConfig(level=logging.INFO)
+
+        self.default_conf = ConfigFactory.from_dict({})
+
+    def test_extraction_with_empty_query_result(self) -> None:
+        """
+        Test Extraction with empty result from query
+        """
+        extractor = CassandraExtractor()
+        extractor.init(self.default_conf)
+
+        results = extractor.extract()
+        self.assertEqual(results, None)
+
+    @patch('databuilder.extractor.cassandra_extractor.CassandraExtractor._get_keyspaces')
+    @patch('databuilder.extractor.cassandra_extractor.CassandraExtractor._get_tables')
+    @patch('databuilder.extractor.cassandra_extractor.CassandraExtractor._get_columns')
+    def test_extraction_with_default_conf(self,
+                                          mock_columns: Any,
+                                          mock_tables: Any,
+                                          mock_keyspaces: Any
+                                          ) -> None:
+        mock_keyspaces.return_value = {'test_schema': None}
+        mock_tables.return_value = {'test_table': None}
+        columns_dict = OrderedDict()
+        columns_dict['id'] = CassandraColumnMetadata(None, 'id', 'int')
+        columns_dict['txt'] = CassandraColumnMetadata(None, 'txt', 'text')
+        mock_columns.return_value = columns_dict
+
+        extractor = CassandraExtractor()
+        extractor.init(self.default_conf)
+        actual = extractor.extract()
+        expected = TableMetadata('cassandra', 'gold', 'test_schema', 'test_table', None,
+                                 [ColumnMetadata('id', None, 'int', 0),
+                                  ColumnMetadata('txt', None, 'text', 1)])
+        self.assertEqual(expected.__repr__(), actual.__repr__())
+        self.assertIsNone(extractor.extract())
+
+    @patch('databuilder.extractor.cassandra_extractor.CassandraExtractor._get_keyspaces')
+    @patch('databuilder.extractor.cassandra_extractor.CassandraExtractor._get_tables')
+    @patch('databuilder.extractor.cassandra_extractor.CassandraExtractor._get_columns')
+    def test_extraction_with_filter_conf(self,
+                                         mock_columns: Any,
+                                         mock_tables: Any,
+                                         mock_keyspaces: Any
+                                         ) -> None:
+        mock_keyspaces.return_value = {'test_schema': None}
+        mock_tables.return_value = {'test_table': None}
+        columns_dict = OrderedDict()
+        columns_dict['id'] = CassandraColumnMetadata(None, 'id', 'int')
+        columns_dict['txt'] = CassandraColumnMetadata(None, 'txt', 'text')
+        mock_columns.return_value = columns_dict
+
+        def filter_function(k: str, t: str) -> bool:
+            return False if 'test' in k or 'test' in t else True
+
+        conf = ConfigFactory.from_dict({
+            CassandraExtractor.FILTER_FUNCTION_KEY: filter_function
+        })
+
+        extractor = CassandraExtractor()
+        extractor.init(conf)
+        self.assertIsNone(extractor.extract())
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/databuilder/tests/unit/extractor/test_csv_extractor.py b/databuilder/tests/unit/extractor/test_csv_extractor.py
new file mode 100644
index 0000000000..066fe1ea45
--- /dev/null
+++ b/databuilder/tests/unit/extractor/test_csv_extractor.py
@@ -0,0 +1,138 @@
+# Copyright Contributors to the Amundsen project.
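The filter test above exercises `CassandraExtractor.FILTER_FUNCTION_KEY`, a configuration entry whose value is a callable taking a keyspace name and a table name and returning whether that pair should be extracted. A minimal usage sketch follows; the `analytics` keyspace is a hypothetical example, and a real pipeline would also supply the cluster connection settings that the patched test omits.

```python
from pyhocon import ConfigFactory

from databuilder.extractor.cassandra_extractor import CassandraExtractor


def keep_analytics_only(keyspace: str, table: str) -> bool:
    # Only extract tables from the (hypothetical) 'analytics' keyspace.
    return keyspace == 'analytics'


conf = ConfigFactory.from_dict({
    CassandraExtractor.FILTER_FUNCTION_KEY: keep_analytics_only,
})

extractor = CassandraExtractor()
extractor.init(conf)  # a real run also needs Cassandra connection configuration
```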
+# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.csv_extractor import ( + CsvColumnLineageExtractor, CsvExtractor, CsvTableBadgeExtractor, CsvTableColumnExtractor, CsvTableLineageExtractor, + split_badge_list, +) +from databuilder.models.badge import Badge + + +class TestCsvExtractor(unittest.TestCase): + + def test_extraction_with_model_class(self) -> None: + """ + Test Extraction using model class + """ + config_dict = { + f'extractor.csv.{CsvExtractor.FILE_LOCATION}': 'example/sample_data/sample_table.csv', + 'extractor.csv.model_class': 'databuilder.models.table_metadata.TableMetadata', + } + self.conf = ConfigFactory.from_dict(config_dict) + extractor = CsvExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + + result = extractor.extract() + self.assertEqual(result.name, 'test_table1') + self.assertEqual(result.description.text, '1st test table') + self.assertEqual(result.database, 'hive') + self.assertEqual(result.cluster, 'gold') + self.assertEqual(result.schema, 'test_schema') + self.assertEqual(result.tags, ['tag1', 'tag2']) + self.assertEqual(result.is_view, 'false') + + result2 = extractor.extract() + self.assertEqual(result2.name, 'test_table2') + self.assertEqual(result2.is_view, 'false') + + result3 = extractor.extract() + self.assertEqual(result3.name, 'test_view1') + self.assertEqual(result3.is_view, 'true') + + def test_extraction_table_badges(self) -> None: + """ + Tests that badges are properly parsed from a CSV file and assigned to a table. + """ + config_dict = { + f'extractor.csvtablebadge.{CsvTableBadgeExtractor.TABLE_FILE_LOCATION}': + 'example/sample_data/sample_table.csv', + f'extractor.csvtablebadge.{CsvTableBadgeExtractor.BADGE_FILE_LOCATION}': + 'example/sample_data/sample_badges.csv', + } + self.conf = ConfigFactory.from_dict(config_dict) + extractor = CsvTableBadgeExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + result_1 = extractor.extract() + self.assertEqual([b.name for b in result_1.badges], ['beta']) + + result_2 = extractor.extract() + self.assertEqual([b.name for b in result_2.badges], ['json', 'npi']) + + def test_extraction_of_tablecolumn_badges(self) -> None: + """ + Test Extraction using the combined CsvTableModel model class + """ + config_dict = { + f'extractor.csvtablecolumn.{CsvTableColumnExtractor.TABLE_FILE_LOCATION}': + 'example/sample_data/sample_table.csv', + f'extractor.csvtablecolumn.{CsvTableColumnExtractor.COLUMN_FILE_LOCATION}': + 'example/sample_data/sample_col.csv', + } + self.conf = ConfigFactory.from_dict(config_dict) + + extractor = CsvTableColumnExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + + result = extractor.extract() + self.assertEqual(result.name, 'test_table1') + self.assertEqual(result.columns[0].badges, [Badge('pk', 'column')]) + self.assertEqual(result.columns[1].badges, [Badge('pii', 'column')]) + self.assertEqual(result.columns[2].badges, [Badge('fk', 'column'), Badge('pii', 'column')]) + + def test_extraction_table_lineage(self) -> None: + """ + Test table lineage extraction using model class + """ + config_dict = { + f'extractor.csvtablelineage.{CsvTableLineageExtractor.TABLE_LINEAGE_FILE_LOCATION}': + 'example/sample_data/sample_table_lineage.csv' + } + self.conf = ConfigFactory.from_dict(config_dict) + extractor = CsvTableLineageExtractor() + 
extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + + result = extractor.extract() + self.assertEqual(result.table_key, 'hive://gold.test_schema/test_table1') + self.assertEqual(result.downstream_deps, ['dynamo://gold.test_schema/test_table2']) + + def test_extraction_column_lineage(self) -> None: + """ + Test column lineage extraction using model class + """ + config_dict = { + f'extractor.csvcolumnlineage.{CsvColumnLineageExtractor.COLUMN_LINEAGE_FILE_LOCATION}': + 'example/sample_data/sample_column_lineage.csv' + } + self.conf = ConfigFactory.from_dict(config_dict) + extractor = CsvColumnLineageExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + + result = extractor.extract() + self.assertEqual(result.column_key, 'hive://gold.test_schema/test_table1/col1') + self.assertEqual(result.downstream_deps, ['dynamo://gold.test_schema/test_table2/col1']) + + def test_split_badge_list(self) -> None: + """ + Test spliting a string of badges into a list, removing all empty badges. + """ + badge_list_1 = 'badge1' + result_1 = split_badge_list(badges=badge_list_1, separator=',') + self.assertEqual(result_1, ['badge1']) + + badge_list_2 = '' + result_2 = split_badge_list(badges=badge_list_2, separator=',') + self.assertEqual(result_2, []) + + badge_list_3 = 'badge1|badge2|badge3' + result_3 = split_badge_list(badges=badge_list_3, separator='|') + self.assertEqual(result_3, ['badge1', 'badge2', 'badge3']) diff --git a/databuilder/tests/unit/extractor/test_dbt_extractor.py b/databuilder/tests/unit/extractor/test_dbt_extractor.py new file mode 100644 index 0000000000..1371604f80 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_dbt_extractor.py @@ -0,0 +1,355 @@ +# Copyright Contributors to the Amundsen project. 
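The final test above pins down the observable contract of `split_badge_list`: split on a configurable separator and drop empty entries. For readers skimming the tests, here is one implementation that satisfies those assertions; it is a sketch only and deliberately named differently, since the real function lives in `databuilder.extractor.csv_extractor` and may differ in detail.

```python
from typing import List


def split_badge_list_sketch(badges: str, separator: str) -> List[str]:
    """Split a separator-delimited badge string, discarding empty badges."""
    return [badge.strip() for badge in badges.split(separator) if badge.strip()]


assert split_badge_list_sketch('badge1', ',') == ['badge1']
assert split_badge_list_sketch('', ',') == []
assert split_badge_list_sketch('badge1|badge2|badge3', '|') == ['badge1', 'badge2', 'badge3']
```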
+# SPDX-License-Identifier: Apache-2.0 + +import json +import unittest +from typing import ( + Any, Optional, Union, no_type_check, +) + +import pyhocon +import pytest +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.dbt_extractor import DbtExtractor, InvalidDbtInputs +from databuilder.models.badge import Badge, BadgeMetadata +from databuilder.models.table_lineage import TableLineage +from databuilder.models.table_metadata import TableMetadata +from databuilder.models.table_source import TableSource + + +def _extract_until_not_these(extractor: DbtExtractor, + classes: Any) -> Optional[Union[BadgeMetadata, TableLineage, TableMetadata, TableSource]]: + # Move to the next type of extracted class: + r = extractor.extract() + while isinstance(r, tuple(classes)): + r = extractor.extract() + return r + + +class TestCsvExtractor(unittest.TestCase): + + database_name = 'snowflake' + catalog_file_loc = 'example/sample_data/dbt/catalog.json' + manifest_data = 'example/sample_data/dbt/manifest.json' + source_url = 'test_url' + + @no_type_check + def test_extraction_with_model_class(self) -> None: + """ + Test Extraction using model class + """ + config_dict = { + f'extractor.dbt.{DbtExtractor.DATABASE_NAME}': self.database_name, + f'extractor.dbt.{DbtExtractor.CATALOG_JSON}': self.catalog_file_loc, + f'extractor.dbt.{DbtExtractor.MANIFEST_JSON}': self.manifest_data, + f'extractor.dbt.{DbtExtractor.SOURCE_URL}': self.source_url + } + self.conf = ConfigFactory.from_dict(config_dict) + extractor = DbtExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + + # One block of tests for each type of model created + extracted_classes = [] + + result = extractor.extract() + self.assertTrue(isinstance(result, TableMetadata)) + self.assertEqual(result.name, 'fact_third_party_performance') + self.assertEqual(result.description.text, 'the performance for third party vendors loss rate by day.') + self.assertEqual(result.database, self.database_name) + self.assertEqual(result.cluster, 'dbt_demo') + self.assertEqual(result.schema, 'public') + self.assertEqual(result.tags, []) + self.assertEqual(result.is_view, True) + extracted_classes.append(TableMetadata) + + result2 = _extract_until_not_these(extractor, extracted_classes) + self.assertTrue(isinstance(result2, TableSource)) + self.assertEqual(result2.db, self.database_name) + self.assertEqual(result2.cluster, 'dbt_demo') + self.assertEqual(result2.schema, 'public') + self.assertEqual(result2.table, 'fact_third_party_performance') + self.assertEqual(result2.source, 'test_url/models/call_center/fact_third_party_performance.sql') + extracted_classes.append(TableSource) + + result3 = _extract_until_not_these(extractor, extracted_classes) + self.assertTrue(isinstance(result3, BadgeMetadata)) + self.assertEqual(result3.badges, [Badge('finance', 'table'), Badge('certified', 'table')]) + extracted_classes.append(BadgeMetadata) + + result4 = _extract_until_not_these(extractor, extracted_classes) + self.assertTrue(isinstance(result4, TableLineage)) + self.assertEqual(result4.table_key, 'snowflake://dbt_demo.public/fact_catalog_returns') + self.assertEqual(result4.downstream_deps, ['snowflake://dbt_demo.public/fact_third_party_performance']) + extracted_classes.append(TableLineage) + + # Should not be any other unique models created + result5 = _extract_until_not_these(extractor, extracted_classes) + self.assertEqual(result5, None) + + @no_type_check + def 
test_dbt_file_inputs_as_json_dumps(self) -> None: + """ + Tests to ensure that the same content can be extracted when the manifest.json + and catalog.json are provided as a file location or as a json.dumps() object + """ + config_dict_1 = { + f'extractor.dbt.{DbtExtractor.DATABASE_NAME}': self.database_name, + f'extractor.dbt.{DbtExtractor.CATALOG_JSON}': self.catalog_file_loc, + f'extractor.dbt.{DbtExtractor.MANIFEST_JSON}': self.manifest_data, + f'extractor.dbt.{DbtExtractor.SOURCE_URL}': self.source_url + } + conf_1 = ConfigFactory.from_dict(config_dict_1) + extractor_1 = DbtExtractor() + extractor_1.init(Scoped.get_scoped_conf(conf=conf_1, scope=extractor_1.get_scope())) + + with open(self.catalog_file_loc, 'r') as f: + catalog_as_json = json.dumps(json.loads(f.read().lower())) + + with open(self.manifest_data, 'r') as f: + manifest_as_json = json.dumps(json.loads(f.read().lower())) + + config_dict_2 = { + f'extractor.dbt.{DbtExtractor.DATABASE_NAME}': self.database_name, + f'extractor.dbt.{DbtExtractor.CATALOG_JSON}': catalog_as_json, + f'extractor.dbt.{DbtExtractor.MANIFEST_JSON}': manifest_as_json + } + conf_2 = ConfigFactory.from_dict(config_dict_2) + extractor_2 = DbtExtractor() + extractor_2.init(Scoped.get_scoped_conf(conf=conf_2, scope=extractor_2.get_scope())) + + result_1 = extractor_1.extract() + result_2 = extractor_2.extract() + self.assertEqual(result_1.name, result_2.name) + self.assertEqual(result_1.description.text, result_2.description.text) + self.assertEqual(result_1.database, result_2.database) + self.assertEqual(result_1.cluster, result_2.cluster) + self.assertEqual(result_1.schema, result_2.schema) + self.assertEqual(result_1.tags, result_2.tags) + self.assertEqual(result_1.is_view, result_2.is_view) + + @no_type_check + def test_keys_retain_original_format(self) -> None: + """ + Test that the database name, cluster name, schema and table name do not + have lowercase auto applied. + """ + config_dict = { + f'extractor.dbt.{DbtExtractor.DATABASE_NAME}': self.database_name.upper(), # Force upper for test + f'extractor.dbt.{DbtExtractor.CATALOG_JSON}': self.catalog_file_loc, + f'extractor.dbt.{DbtExtractor.MANIFEST_JSON}': self.manifest_data, + f'extractor.dbt.{DbtExtractor.FORCE_TABLE_KEY_LOWER}': False + } + conf = ConfigFactory.from_dict(config_dict) + extractor = DbtExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, scope=extractor.get_scope())) + + result = extractor.extract() + + self.assertEqual(result.name, 'fact_third_party_performance') + self.assertEqual(result.database, 'SNOWFLAKE') + self.assertEqual(result.cluster, 'dbt_demo') + self.assertEqual(result.schema, 'public') + + def test_do_not_extract_tables(self) -> None: + """ + Test that tables are not extracted. + """ + config_dict = { + f'extractor.dbt.{DbtExtractor.DATABASE_NAME}': self.database_name.upper(), + f'extractor.dbt.{DbtExtractor.CATALOG_JSON}': self.catalog_file_loc, + f'extractor.dbt.{DbtExtractor.MANIFEST_JSON}': self.manifest_data, + f'extractor.dbt.{DbtExtractor.EXTRACT_TABLES}': False + } + conf = ConfigFactory.from_dict(config_dict) + extractor = DbtExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, scope=extractor.get_scope())) + + has_next = True + while has_next: + extraction = extractor.extract() + self.assertFalse(isinstance(extraction, TableMetadata)) + if extraction is None: + break + + def test_do_not_extract_descriptions(self) -> None: + """ + Test that tables are not extracted. 
+ """ + config_dict = { + f'extractor.dbt.{DbtExtractor.DATABASE_NAME}': self.database_name.upper(), + f'extractor.dbt.{DbtExtractor.CATALOG_JSON}': self.catalog_file_loc, + f'extractor.dbt.{DbtExtractor.MANIFEST_JSON}': self.manifest_data, + f'extractor.dbt.{DbtExtractor.EXTRACT_DESCRIPTIONS}': False + } + conf = ConfigFactory.from_dict(config_dict) + extractor = DbtExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, scope=extractor.get_scope())) + + has_next = True + while has_next: + extraction = extractor.extract() + if isinstance(extraction, TableMetadata): + # No table descriptions + self.assertEqual(extraction.description, None) + + # No column descriptions + for col in extraction.columns: + self.assertEqual(col.description, None) + + if extraction is None: + break + + def test_do_not_extract_dbt_tags(self) -> None: + """ + Test that tags are not extracted as Badges + """ + config_dict = { + f'extractor.dbt.{DbtExtractor.DATABASE_NAME}': self.database_name.upper(), + f'extractor.dbt.{DbtExtractor.CATALOG_JSON}': self.catalog_file_loc, + f'extractor.dbt.{DbtExtractor.MANIFEST_JSON}': self.manifest_data, + f'extractor.dbt.{DbtExtractor.EXTRACT_TAGS}': False + } + conf = ConfigFactory.from_dict(config_dict) + extractor = DbtExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, scope=extractor.get_scope())) + + has_next = True + while has_next: + extraction = extractor.extract() + self.assertFalse(isinstance(extraction, BadgeMetadata)) + if extraction is None: + break + + def test_import_tags_as_tags(self) -> None: + """ + Test that dbt tags can be configured to be imported as Amundsen tags. + """ + config_dict = { + f'extractor.dbt.{DbtExtractor.DATABASE_NAME}': self.database_name.upper(), + f'extractor.dbt.{DbtExtractor.CATALOG_JSON}': self.catalog_file_loc, + f'extractor.dbt.{DbtExtractor.MANIFEST_JSON}': self.manifest_data, + f'extractor.dbt.{DbtExtractor.IMPORT_TAGS_AS}': 'tag' + } + conf = ConfigFactory.from_dict(config_dict) + extractor = DbtExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, scope=extractor.get_scope())) + + # The 7th table has tags + extraction = [extractor.extract() for i in range(6)][-1] + self.assertEqual(extraction.tags, ['finance', 'certified']) # type: ignore + + def test_do_not_extract_dbt_lineage(self) -> None: + """ + Test that table level lineage is not extracted from dbt + """ + config_dict = { + f'extractor.dbt.{DbtExtractor.DATABASE_NAME}': self.database_name.upper(), + f'extractor.dbt.{DbtExtractor.CATALOG_JSON}': self.catalog_file_loc, + f'extractor.dbt.{DbtExtractor.MANIFEST_JSON}': self.manifest_data, + f'extractor.dbt.{DbtExtractor.EXTRACT_LINEAGE}': False + } + conf = ConfigFactory.from_dict(config_dict) + extractor = DbtExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, scope=extractor.get_scope())) + + has_next = True + while has_next: + extraction = extractor.extract() + self.assertFalse(isinstance(extraction, TableLineage)) + if extraction is None: + break + + def test_alias_for_table_name(self) -> None: + """ + Test that table level lineage is not extracted from dbt + """ + config_dict = { + f'extractor.dbt.{DbtExtractor.DATABASE_NAME}': self.database_name.upper(), + f'extractor.dbt.{DbtExtractor.CATALOG_JSON}': self.catalog_file_loc, + f'extractor.dbt.{DbtExtractor.MANIFEST_JSON}': self.manifest_data, + f'extractor.dbt.{DbtExtractor.MODEL_NAME_KEY}': 'alias' + } + conf = ConfigFactory.from_dict(config_dict) + extractor = DbtExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, + 
scope=extractor.get_scope())) + + result = extractor.extract() + known_alias = 'cost_summary' # One table aliased as "cost_summary" + known_alias_cnt = 0 + while result: + if isinstance(result, TableMetadata): + self.assertNotEqual(result.name, 'fact_daily_expenses') + if result.name == known_alias: + known_alias_cnt += 1 + result = extractor.extract() + self.assertEqual(known_alias_cnt, 1) + + def test_filter_schema_name(self) -> None: + """ + Test that table level lineage is not extracted from dbt + """ + config_dict = { + f'extractor.dbt.{DbtExtractor.DATABASE_NAME}': self.database_name.upper(), + f'extractor.dbt.{DbtExtractor.CATALOG_JSON}': self.catalog_file_loc, + f'extractor.dbt.{DbtExtractor.MANIFEST_JSON}': self.manifest_data, + f'extractor.dbt.{DbtExtractor.EXTRACT_LINEAGE}': False, + f'extractor.dbt.{DbtExtractor.SCHEMA_FILTER}': 'other_schema_value' + } + conf = ConfigFactory.from_dict(config_dict) + extractor = DbtExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + + # Tests currently have 1 schema defined + result = extractor.extract() + self.assertEqual(result, None) + + def test_invalid_dbt_inputs(self) -> None: + """ + Test that table level lineage is not extracted from dbt + """ + missing_inputs = [ + { + # f'extractor.dbt.{DbtExtractor.DATABASE_NAME}': self.database_name.upper(), + f'extractor.dbt.{DbtExtractor.CATALOG_JSON}': self.catalog_file_loc, + f'extractor.dbt.{DbtExtractor.MANIFEST_JSON}': self.manifest_data + }, + { + f'extractor.dbt.{DbtExtractor.DATABASE_NAME}': self.database_name.upper(), + # f'extractor.dbt.{DbtExtractor.CATALOG_JSON}': self.catalog_file_loc, + f'extractor.dbt.{DbtExtractor.MANIFEST_JSON}': self.manifest_data + }, + { + f'extractor.dbt.{DbtExtractor.DATABASE_NAME}': self.database_name.upper(), + f'extractor.dbt.{DbtExtractor.CATALOG_JSON}': self.catalog_file_loc, + # f'extractor.dbt.{DbtExtractor.MANIFEST_JSON}': self.manifest_data + } + ] + for missing_input_config in missing_inputs: + conf = ConfigFactory.from_dict(missing_input_config) + extractor = DbtExtractor() + with pytest.raises(pyhocon.exceptions.ConfigMissingException): + extractor.init(Scoped.get_scoped_conf(conf=conf, scope=extractor.get_scope())) + + # Invalid manifest.json and invalid catalog.json + invalid_file_jsons = [ + { + f'extractor.dbt.{DbtExtractor.DATABASE_NAME}': self.database_name.upper(), + f'extractor.dbt.{DbtExtractor.CATALOG_JSON}': 'not a real file location or json', + f'extractor.dbt.{DbtExtractor.MANIFEST_JSON}': self.manifest_data + }, + { + f'extractor.dbt.{DbtExtractor.DATABASE_NAME}': self.database_name.upper(), + f'extractor.dbt.{DbtExtractor.CATALOG_JSON}': self.catalog_file_loc, + f'extractor.dbt.{DbtExtractor.MANIFEST_JSON}': 'not a real file location or json' + } + ] + for invalid_conf in invalid_file_jsons: + conf = ConfigFactory.from_dict(invalid_conf) + extractor = DbtExtractor() + with pytest.raises(InvalidDbtInputs): + extractor.init(Scoped.get_scoped_conf(conf=conf, scope=extractor.get_scope())) diff --git a/databuilder/tests/unit/extractor/test_deltalake_extractor.py b/databuilder/tests/unit/extractor/test_deltalake_extractor.py new file mode 100644 index 0000000000..a8eaa56af2 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_deltalake_extractor.py @@ -0,0 +1,454 @@ +# Copyright Contributors to the Amundsen project. 
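Both the dbt tests above and the Delta Lake tests that follow drain the extractor with a loop that calls `extract()` until it returns `None`. When that pattern repeats, a small generator can make the intent explicit; the helper below is an optional convenience sketch (the name `drain` is an assumption, not a databuilder utility), demonstrated against a stand-in object with the same `extract()` contract.

```python
from typing import Any, Iterator


def drain(extractor: Any) -> Iterator[Any]:
    """Yield records from an extractor until extract() returns None."""
    record = extractor.extract()
    while record is not None:
        yield record
        record = extractor.extract()


class _FakeExtractor:
    # Stand-in with the same extract() contract as the databuilder extractors.
    def __init__(self) -> None:
        self._records = ['a', 'b', 'c']

    def extract(self) -> Any:
        return self._records.pop(0) if self._records else None


assert list(drain(_FakeExtractor())) == ['a', 'b', 'c']
```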
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import tempfile +import unittest +from typing import Dict + +from pyhocon import ConfigFactory +# patch whole class to avoid actually calling for boto3.client during tests +from pyspark.sql import SparkSession +from pyspark.sql.catalog import Table + +from databuilder import Scoped +from databuilder.extractor.delta_lake_metadata_extractor import ( + DeltaLakeMetadataExtractor, ScrapedColumnMetadata, ScrapedTableMetadata, +) +from databuilder.extractor.table_metadata_constants import PARTITION_BADGE +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata +from databuilder.models.watermark import Watermark + + +class TestDeltaLakeExtractor(unittest.TestCase): + + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.spark = SparkSession.builder \ + .appName("Amundsen Delta Lake Metadata Extraction") \ + .master("local") \ + .config("spark.jars.packages", "io.delta:delta-core_2.12:0.7.0") \ + .config("spark.sql.warehouse.dir", tempfile.TemporaryDirectory()) \ + .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \ + .config("spark.driver.host", "127.0.0.1") \ + .config("spark.driver.bindAddress", "127.0.0.1") \ + .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") \ + .getOrCreate() + self.config_dict = { + f'extractor.delta_lake_table_metadata.{DeltaLakeMetadataExtractor.CLUSTER_KEY}': 'test_cluster', + f'extractor.delta_lake_table_metadata.{DeltaLakeMetadataExtractor.SCHEMA_LIST_KEY}': [], + f'extractor.delta_lake_table_metadata.{DeltaLakeMetadataExtractor.EXCLUDE_LIST_SCHEMAS_KEY}': [], + f'extractor.delta_lake_table_metadata.{DeltaLakeMetadataExtractor.DATABASE_KEY}': 'test_database', + f'extractor.delta_lake_table_metadata.{DeltaLakeMetadataExtractor.DELTA_TABLES_ONLY}': False + } + conf = ConfigFactory.from_dict(self.config_dict) + self.dExtractor = DeltaLakeMetadataExtractor() + self.dExtractor.init(Scoped.get_scoped_conf(conf=conf, scope=self.dExtractor.get_scope())) + self.dExtractor.set_spark(self.spark) + self.setUpSchemas() + + def setUpSchemas(self) -> None: + self.spark.sql("create schema if not exists test_schema1") + self.spark.sql("create schema if not exists test_schema2") + self.spark.sql("create table if not exists test_schema1.test_table1 (a string, b int) using delta") + self.spark.sql("create table if not exists " + "test_schema1.test_table3 (c boolean, d float) using delta partitioned by (c)") + self.spark.sql("create table if not exists test_schema2.test_parquet (a string) using parquet") + self.spark.sql("create table if not exists test_schema2.test_table2 (a2 string, b2 double) using delta") + # TODO do we even need to support views and none delta tables in this case? 
+ self.spark.sql("create view if not exists test_schema2.test_view1 as (select * from test_schema2.test_table2)") + + self.spark.sql("create table if not exists " + "test_schema2.watermarks_single_partition (date date, value float) using delta partitioned by" + "(date)") + self.spark.sql("insert into test_schema2.watermarks_single_partition values " + "('2020-12-03', 1337), ('2020-12-02', 42), ('2020-12-01', 42), ('2020-12-05', 42)," + "('2020-12-04', 42)") + self.spark.sql("create table if not exists " + "test_schema2.watermarks_multi_partition (date date, spec int, value float) using delta " + "partitioned by (date, spec)") + self.spark.sql("insert into test_schema2.watermarks_multi_partition values " + "('2020-12-03', 1, 1337), ('2020-12-02', 2, 42), ('2020-12-01', 2, 42), ('2020-12-05', 3, 42)," + "('2020-12-04', 1, 42)") + # Nested/Complex schemas + self.spark.sql("create schema if not exists complex_schema") + self.spark.sql("create table if not exists complex_schema.struct_table (a int, struct_col struct) using delta") + self.spark.sql("create table if not exists complex_schema.nested_struct_table (a int, struct_col " + "struct>) using delta") + self.spark.sql("create table if not exists complex_schema.array_table (a int, arr_col array) using " + "delta") + self.spark.sql("create table if not exists complex_schema.array_complex_elem_table (a int, arr_col " + "array>) using delta") + self.spark.sql("create table if not exists complex_schema.map_table (a int, map_col map) using " + "delta") + self.spark.sql("create table if not exists complex_schema.map_complex_key_table (a int, map_col " + "map>,e:double>,int>) using delta") + self.spark.sql("create table if not exists complex_schema.map_complex_value_table (a int, map_col map>,e:double>>) using delta") + self.spark.sql("create table if not exists complex_schema.map_complex_key_and_value_table (a int, map_col " + "map>,e:double>,struct>," + "i:double>>) using delta") + + self.spark.sql("create table if not exists complex_schema.array_of_array (a array>) using " + "delta") + self.spark.sql("create table if not exists complex_schema.map_of_map (a map>) using " + "delta") + self.spark.sql("create table if not exists complex_schema.map_of_array_of_structs (a map>>) using delta") + + def test_get_all_schemas(self) -> None: + '''Tests getting all schemas''' + actual_schemas = self.dExtractor.get_schemas([]) + self.assertEqual(["complex_schema", "default", "test_schema1", "test_schema2"], actual_schemas) + + def test_get_all_schemas_with_exclude(self) -> None: + '''Tests the exclude list''' + actual_schemas = self.dExtractor.get_schemas(["complex_schema", "default"]) + self.assertEqual(["test_schema1", "test_schema2"], actual_schemas) + + def test_get_all_tables(self) -> None: + '''Tests table fetching''' + actual = [x.name for x in self.dExtractor.get_all_tables(["test_schema1", "default"])] + self.assertEqual(["test_table1", "test_table3"], actual) + + def test_scrape_table_detail(self) -> None: + '''Test Table Detail Scraping''' + actual = self.dExtractor.scrape_table_detail("test_schema1.test_table1") + expected: Dict = {'createdAt': None, + 'description': None, + 'format': 'delta', + 'id': None, + 'lastModified': None, + 'location': None, + 'minReaderVersion': 1, + 'minWriterVersion': 2, + 'name': 'test_schema1.test_table1', + 'numFiles': 0, + 'partitionColumns': [], + 'properties': {}, + 'sizeInBytes': 0} + self.assertIsNotNone(actual) + if actual: + self.assertEqual(actual.keys(), expected.keys()) + self.assertEqual(actual['name'], 
expected['name']) + self.assertEqual(actual['format'], expected['format']) + + def test_scrape_view_detail(self) -> None: + actual = self.dExtractor.scrape_view_detail("test_schema2.test_view1") + self.assertIsNotNone(actual) + expected = {'Created By': 'Spark 3.0.1', + 'Created Time': None, + 'Database': 'test_schema2', + 'Last Access': 'UNKNOWN', + 'Table': 'test_view1', + 'Table Properties': '[view.catalogAndNamespace.numParts=2, ' + 'view.query.out.col.0=a2, view.query.out.numCols=2, ' + 'view.query.out.col.1=b2, ' + 'view.catalogAndNamespace.part.0=spark_catalog, ' + 'view.catalogAndNamespace.part.1=default]', + 'Type': 'VIEW', + 'View Catalog and Namespace': 'spark_catalog.default', + 'View Original Text': '(select * from test_schema2.test_table2)', + 'View Query Output Columns': '[a2, b2]', + 'View Text': '(select * from test_schema2.test_table2)'} + if actual: + actual['Created Time'] = None + self.assertEqual(actual, expected) + + def test_fetch_partitioned_delta_columns(self) -> None: + actual = self.dExtractor.fetch_columns("test_schema1", "test_table3") + partition_column = ScrapedColumnMetadata(name="c", description=None, data_type="boolean", sort_order=0) + partition_column.set_is_partition(True) + partition_column.set_badges([PARTITION_BADGE]) + expected = [partition_column, + ScrapedColumnMetadata(name="d", description=None, data_type="float", sort_order=1)] + for a, b in zip(actual, expected): + self.assertEqual(a, b) + + def test_fetch_delta_columns(self) -> None: + actual = self.dExtractor.fetch_columns("test_schema1", "test_table1") + expected = [ScrapedColumnMetadata(name="a", description=None, data_type="string", sort_order=0), + ScrapedColumnMetadata(name="b", description=None, data_type="int", sort_order=1)] + for a, b in zip(actual, expected): + self.assertEqual(a, b) + + def test_fetch_delta_columns_failure(self) -> None: + actual = self.dExtractor.fetch_columns("test_schema1", "nonexistent_table") + self.assertEquals(actual, []) + + def test_scrape_tables(self) -> None: + table = Table(name="test_table1", database="test_schema1", description=None, + tableType="delta", isTemporary=False) + actual = self.dExtractor.scrape_table(table) + + expected = ScrapedTableMetadata(schema="test_schema1", table="test_table1") + expected.set_columns([ScrapedColumnMetadata(name="a", description=None, data_type="string", sort_order=0), + ScrapedColumnMetadata(name="b", description=None, data_type="int", sort_order=1)]) + if actual is not None: + self.assertEqual(expected.schema, actual.schema) + self.assertEqual(expected.table, actual.table) + self.assertEqual(expected.columns, actual.columns) + self.assertEqual(expected.failed_to_scrape, actual.failed_to_scrape) + self.assertEqual(expected.is_view, actual.is_view) + self.assertIsNotNone(actual.table_detail) + else: + self.assertIsNotNone(actual) + + def test_create_table_metadata(self) -> None: + scraped = ScrapedTableMetadata(schema="test_schema1", table="test_table1") + scraped.set_columns([ScrapedColumnMetadata(name="a", description=None, data_type="string", sort_order=0), + ScrapedColumnMetadata(name="b", description=None, data_type="int", sort_order=1)]) + created_metadata = self.dExtractor.create_table_metadata(scraped) + expected = TableMetadata("test_database", "test_cluster", "test_schema1", "test_table1", description=None, + columns=[ColumnMetadata("a", None, "string", 0), + ColumnMetadata("b", None, "int", 1)]) + self.assertEqual(str(expected), str(created_metadata)) + + def test_create_last_updated(self) -> None: + 
scraped_table = self.dExtractor.scrape_table(Table("test_table1", "test_schema1", None, "delta", False)) + actual_last_updated = None + if scraped_table: + actual_last_updated = self.dExtractor.create_table_last_updated(scraped_table) + self.assertIsNotNone(actual_last_updated) + + def test_extract(self) -> None: + ret = [] + data = self.dExtractor.extract() + while data is not None: + ret.append(data) + data = self.dExtractor.extract() + self.assertEqual(len(ret), 40) + + def test_extract_with_only_specific_schemas(self) -> None: + self.config_dict = { + f'extractor.delta_lake_table_metadata.{DeltaLakeMetadataExtractor.CLUSTER_KEY}': 'test_cluster', + f'extractor.delta_lake_table_metadata.{DeltaLakeMetadataExtractor.SCHEMA_LIST_KEY}': ['test_schema2'], + f'extractor.delta_lake_table_metadata.{DeltaLakeMetadataExtractor.EXCLUDE_LIST_SCHEMAS_KEY}': [], + f'extractor.delta_lake_table_metadata.{DeltaLakeMetadataExtractor.DATABASE_KEY}': 'test_database' + } + conf = ConfigFactory.from_dict(self.config_dict) + self.dExtractor.init(Scoped.get_scoped_conf(conf=conf, + scope=self.dExtractor.get_scope())) + ret = [] + data = self.dExtractor.extract() + while data is not None: + ret.append(data) + data = self.dExtractor.extract() + self.assertEqual(len(ret), 12) + + def test_extract_when_excluding(self) -> None: + self.config_dict = { + f'extractor.delta_lake_table_metadata.{DeltaLakeMetadataExtractor.CLUSTER_KEY}': 'test_cluster', + f'extractor.delta_lake_table_metadata.{DeltaLakeMetadataExtractor.SCHEMA_LIST_KEY}': [], + f'extractor.delta_lake_table_metadata.{DeltaLakeMetadataExtractor.EXCLUDE_LIST_SCHEMAS_KEY}': + ['test_schema2'], + f'extractor.delta_lake_table_metadata.{DeltaLakeMetadataExtractor.DATABASE_KEY}': 'test_database' + } + conf = ConfigFactory.from_dict(self.config_dict) + self.dExtractor.init(Scoped.get_scoped_conf(conf=conf, + scope=self.dExtractor.get_scope())) + ret = [] + data = self.dExtractor.extract() + while data is not None: + ret.append(data) + data = self.dExtractor.extract() + self.assertEqual(len(ret), 26) + + def test_table_does_not_exist(self) -> None: + table = Table(name="test_table5", database="test_schema1", description=None, + tableType="delta", isTemporary=False) + actual = self.dExtractor.scrape_table(table) + self.assertIsNone(actual) + + def test_scrape_all_tables(self) -> None: + tables = [Table(name="test_table1", database="test_schema1", description=None, + tableType="delta", isTemporary=False), + Table(name="test_table3", database="test_schema1", description=None, + tableType="delta", isTemporary=False)] + actual = self.dExtractor.scrape_all_tables(tables) + self.assertEqual(2, len(actual)) + + def test_scrape_complex_schema_no_config(self) -> None: + # Don't set the extract_nested_columns config to verify backwards compatibility + actual = self.dExtractor.fetch_columns(schema="complex_schema", table="struct_table") + self.assertEqual(actual, [ + ScrapedColumnMetadata(name="a", description=None, data_type="int", sort_order=0), + ScrapedColumnMetadata(name="struct_col", description=None, data_type="struct", + sort_order=1), + ]) + + def test_scrape_complex_schema_columns(self) -> None: + # set extract_nested_columns to True to test complex column extraction + self.dExtractor.extract_nested_columns = True + expected_dict = { + "struct_table": [ + ScrapedColumnMetadata(name="a", description=None, data_type="int", sort_order=0), + ScrapedColumnMetadata(name="struct_col", description=None, data_type="struct", + sort_order=1), + 
ScrapedColumnMetadata(name="struct_col.b", description=None, data_type="string", sort_order=2), + ScrapedColumnMetadata(name="struct_col.c", description=None, data_type="double", sort_order=3), + ], + "nested_struct_table": [ + ScrapedColumnMetadata(name="a", description=None, data_type="int", sort_order=0), + ScrapedColumnMetadata(name="struct_col", description=None, + data_type="struct>", sort_order=1), + ScrapedColumnMetadata(name="struct_col.nested_struct_col", description=None, + data_type="struct", sort_order=2), + ScrapedColumnMetadata(name="struct_col.nested_struct_col.b", description=None, data_type="string", + sort_order=3), + ScrapedColumnMetadata(name="struct_col.nested_struct_col.c", description=None, data_type="double", + sort_order=4), + ], + "array_table": [ + ScrapedColumnMetadata(name="a", description=None, data_type="int", sort_order=0), + ScrapedColumnMetadata(name="arr_col", description=None, data_type="array", sort_order=1), + ], + "array_complex_elem_table": [ + ScrapedColumnMetadata(name="a", description=None, data_type="int", sort_order=0), + ScrapedColumnMetadata(name="arr_col", description=None, data_type="array>", + sort_order=1), + ScrapedColumnMetadata(name="arr_col.b", description=None, data_type="string", sort_order=2), + ScrapedColumnMetadata(name="arr_col.c", description=None, data_type="double", sort_order=3), + ], + "map_table": [ + ScrapedColumnMetadata(name="a", description=None, data_type="int", sort_order=0), + ScrapedColumnMetadata(name="map_col", description=None, data_type="map", sort_order=1), + ], + "map_complex_key_table": [ + ScrapedColumnMetadata(name="a", description=None, data_type="int", sort_order=0), + ScrapedColumnMetadata(name="map_col", description=None, + data_type="map>,e:double>,int>", + sort_order=1), + ], + "map_complex_value_table": [ + ScrapedColumnMetadata(name="a", description=None, data_type="int", sort_order=0), + ScrapedColumnMetadata(name="map_col", description=None, data_type="map>,e:double>>", + sort_order=1), + ScrapedColumnMetadata(name="map_col.b", description=None, data_type="array>", + sort_order=2), + ScrapedColumnMetadata(name="map_col.b.c", description=None, data_type="int", sort_order=3), + ScrapedColumnMetadata(name="map_col.b.d", description=None, data_type="string", sort_order=4), + ScrapedColumnMetadata(name="map_col.e", description=None, data_type="double", sort_order=5), + ], + "map_complex_key_and_value_table": [ + ScrapedColumnMetadata(name="a", description=None, data_type="int", sort_order=0), + ScrapedColumnMetadata(name="map_col", description=None, + data_type="map>,e:double>," + "struct>,i:double>>", + sort_order=1), + ScrapedColumnMetadata(name="map_col.f", description=None, data_type="array>", + sort_order=2), + ScrapedColumnMetadata(name="map_col.f.g", description=None, data_type="int", sort_order=3), + ScrapedColumnMetadata(name="map_col.f.h", description=None, data_type="string", sort_order=4), + ScrapedColumnMetadata(name="map_col.i", description=None, data_type="double", sort_order=5), + ], + "array_of_array": [ + ScrapedColumnMetadata(name="a", description=None, data_type="array>", sort_order=0), + ], + "map_of_map": [ + ScrapedColumnMetadata(name="a", description=None, data_type="map>", sort_order=0), + ], + "map_of_array_of_structs": [ + ScrapedColumnMetadata(name="a", description=None, data_type="map>>", + sort_order=0), + ScrapedColumnMetadata(name="a.b", description=None, data_type="int", sort_order=1), + ScrapedColumnMetadata(name="a.c", description=None, data_type="string", 
sort_order=2), + ] + } + for table_name, expected in expected_dict.items(): + actual = self.dExtractor.fetch_columns(schema="complex_schema", table=table_name) + self.assertEqual(len(expected), len(actual), f"{table_name} failed") + self.assertListEqual(expected, actual, f"{table_name} failed") + + def test_create_table_watermarks_single_partition(self) -> None: + scraped_table = self.dExtractor.scrape_table(Table("watermarks_single_partition", "test_schema2", None, "delta", + False)) + self.assertIsNotNone(scraped_table) + if scraped_table: + found = self.dExtractor.create_table_watermarks(scraped_table) + self.assertIsNotNone(found) + if found: + self.assertEqual(1, len(found)) + self.assertEqual(2, len(found[0])) + create_time = found[0][0].create_time + expected = [( + Watermark( + create_time=create_time, + database='test_database', + schema='test_schema2', + table_name='watermarks_single_partition', + part_name='date=2020-12-05', + part_type='high_watermark', + cluster='test_cluster'), + Watermark( + create_time=create_time, + database='test_database', + schema='test_schema2', + table_name='watermarks_single_partition', + part_name='date=2020-12-01', + part_type='low_watermark', + cluster='test_cluster') + )] + self.assertEqual(str(expected), str(found)) + + def test_create_table_watermarks_multi_partition(self) -> None: + scraped_table = self.dExtractor.scrape_table(Table("watermarks_multi_partition", "test_schema2", None, "delta", + False)) + self.assertIsNotNone(scraped_table) + if scraped_table: + found = self.dExtractor.create_table_watermarks(scraped_table) + self.assertIsNotNone(found) + if found: + self.assertEqual(2, len(found)) + self.assertEqual(2, len(found[0])) + create_time = found[0][0].create_time + expected = [( + Watermark( + create_time=create_time, + database='test_database', + schema='test_schema2', + table_name='watermarks_multi_partition', + part_name='date=2020-12-05', + part_type='high_watermark', + cluster='test_cluster'), + Watermark( + create_time=create_time, + database='test_database', + schema='test_schema2', + table_name='watermarks_multi_partition', + part_name='date=2020-12-01', + part_type='low_watermark', + cluster='test_cluster') + ), ( + Watermark( + create_time=create_time, + database='test_database', + schema='test_schema2', + table_name='watermarks_multi_partition', + part_name='spec=3', + part_type='high_watermark', + cluster='test_cluster'), + Watermark( + create_time=create_time, + database='test_database', + schema='test_schema2', + table_name='watermarks_multi_partition', + part_name='spec=1', + part_type='low_watermark', + cluster='test_cluster') + )] + self.assertEqual(str(expected), str(found)) + + def test_create_table_watermarks_without_partition(self) -> None: + scraped_table = self.dExtractor.scrape_table(Table("test_table1", "test_schema1", None, "delta", False)) + self.assertIsNotNone(scraped_table) + if scraped_table: + found = self.dExtractor.create_table_watermarks(scraped_table) + self.assertIsNone(found) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/test_dremio_metadata_extractor.py b/databuilder/tests/unit/extractor/test_dremio_metadata_extractor.py new file mode 100644 index 0000000000..9b70db324c --- /dev/null +++ b/databuilder/tests/unit/extractor/test_dremio_metadata_extractor.py @@ -0,0 +1,110 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest +from typing import ( + Any, Dict, List, +) +from unittest.mock import MagicMock, patch + +from pyhocon import ConfigFactory + +from databuilder.extractor.dremio_metadata_extractor import DremioMetadataExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + + +class TestDremioMetadataExtractor(unittest.TestCase): + + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + config_dict: Dict[str, str] = {} + + self.conf = ConfigFactory.from_dict(config_dict) + + @patch('databuilder.extractor.dremio_metadata_extractor.connect') + def test_extraction_with_empty_query_result(self, mock_connect: MagicMock) -> None: + """ + Test Extraction with empty result from query + """ + mock_connection = MagicMock() + mock_connect.return_value = mock_connection + + mock_cursor = MagicMock() + mock_connection.cursor.return_value = mock_cursor + + extractor = DremioMetadataExtractor() + extractor.init(self.conf) + + results = extractor.extract() + self.assertEqual(results, None) + + @patch('databuilder.extractor.dremio_metadata_extractor.connect') + def test_extraction_with_single_result(self, mock_connect: MagicMock) -> None: + """ + Test Extraction with single table result from query + """ + mock_connection = MagicMock() + mock_connect.return_value = mock_connection + + mock_cursor = MagicMock() + mock_connection.cursor.return_value = mock_cursor + + mock_execute = MagicMock() + mock_cursor.execute = mock_execute + + mock_cursor.description = [ + ['col_name'], + ['col_description'], + ['col_type'], + ['col_sort_order'], + ['database'], + ['cluster'], + ['schema'], + ['name'], + ['description'], + ['is_view'] + ] + + # Pass flake8 Unsupported operand types for + error + table: List[Any] = [ + 'DREMIO', + 'Production', + 'test_schema', + 'test_table', + 'a table for testing', + 'false' + ] + + # Pass flake8 Unsupported operand types for + error + expected_input: List[List[Any]] = [ + ['col_id1', 'description of id1', 'number', 0] + table, + ['col_id2', 'description of id2', 'number', 1] + table, + ['is_active', None, 'boolean', 2] + table, + ['source', 'description of source', 'varchar', 3] + table, + ['etl_created_at', 'description of etl_created_at', 'timestamp_ltz', 4] + table, + ['ds', None, 'varchar', 5] + table + ] + + mock_cursor.execute.return_value = expected_input + + extractor = DremioMetadataExtractor() + extractor.init(self.conf) + + actual = extractor.extract() + expected = TableMetadata('DREMIO', 'Production', 'test_schema', 'test_table', 'a table for testing', + [ColumnMetadata('col_id1', 'description of id1', 'number', 0), + ColumnMetadata('col_id2', 'description of id2', 'number', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', + 'timestamp_ltz', 4), + ColumnMetadata('ds', None, 'varchar', 5)]) + + self.assertEqual(expected.__repr__(), actual.__repr__()) + self.assertIsNone(extractor.extract()) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/test_es_column_stats_extractor.py b/databuilder/tests/unit/extractor/test_es_column_stats_extractor.py new file mode 100644 index 0000000000..7fc96fb28e --- /dev/null +++ b/databuilder/tests/unit/extractor/test_es_column_stats_extractor.py @@ -0,0 +1,199 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import Any, Dict + +from elasticsearch import Elasticsearch +from mock import MagicMock +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.es_column_stats_extractor import ElasticsearchColumnStatsExtractor +from databuilder.models.table_stats import TableColumnStats + + +class TestElasticsearchColumnStatsExtractor(unittest.TestCase): + es_version_v6 = '6.0.0' + + es_version_v7 = '7.0.0' + + indices_v6 = { + '.technical_index': { + 'mappings': { + 'doc': { + 'properties': { + 'keyword_property': { + 'type': 'keyword' + }, + 'long_property': { + 'type': 'long' + } + } + } + }, + 'aliases': { + 'search_index': {} + }, + 'settings': { + 'number_of_replicas': 1 + } + }, + 'proper_index': { + 'mappings': { + 'doc': { + 'properties': { + 'keyword_property': { + 'type': 'keyword' + }, + 'long_property': { + 'type': 'long' + } + } + } + }, + 'aliases': { + 'search_index': {} + }, + 'settings': { + 'number_of_replicas': 1 + } + } + } + + indices_v7 = { + '.technical_index': { + 'mappings': { + 'properties': { + 'keyword_property': { + 'type': 'keyword' + }, + 'long_property': { + 'type': 'long' + } + } + }, + 'aliases': { + 'search_index': {} + }, + 'settings': { + 'number_of_replicas': 1 + } + }, + 'proper_index': { + 'mappings': { + 'properties': { + 'keyword_property': { + 'type': 'keyword' + }, + 'long_property': { + 'type': 'long' + } + } + }, + 'aliases': { + 'search_index': {} + }, + 'settings': { + 'number_of_replicas': 1 + } + } + } + + stats_v6 = { + 'aggregations': { + 'stats': { + 'fields': [ + { + 'name': 'long_property', + 'avg': 5, + 'sum': 10, + 'count': 2 + } + ] + } + } + } + + stats_v7 = { + 'aggregations': { + 'stats': { + 'fields': [ + { + 'name': 'long_property', + 'avg': 5, + 'sum': 10, + 'count': 2 + } + ] + } + } + } + + def setUp(self) -> None: + params = {'extractor.es_column_stats.schema': 'schema_name', + 'extractor.es_column_stats.cluster': 'cluster_name', + 'extractor.es_column_stats.client': Elasticsearch()} + + config = ConfigFactory.from_dict(params) + + self.config = config + + def _get_extractor(self) -> Any: + extractor = ElasticsearchColumnStatsExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.config, scope=extractor.get_scope())) + + return extractor + + def _test_extractor_without_technical_data(self, es_version: str, indices: Dict, stats: Dict) -> None: + extractor = self._get_extractor() + + extractor._get_es_version = lambda: es_version + extractor.es.indices.get = MagicMock(return_value=indices) + extractor.es.search = MagicMock(return_value=stats) + + common = { + 'db': 'elasticsearch', + 'schema': 'schema_name', + 'table_name': 'proper_index', + 'cluster': 'cluster_name', + 'start_epoch': '0', + 'end_epoch': '0' + } + + compare_params = {'table', 'schema', 'db', 'col_name', 'start_epoch', + 'end_epoch', 'cluster', 'stat_type', 'stat_val'} + expected = [ + {x: spec[x] for x in compare_params if x in spec} for spec in + [ + TableColumnStats( + **{**dict(stat_name='avg', stat_val='5', col_name='long_property'), **common}).__dict__, + TableColumnStats( + **{**dict(stat_name='sum', stat_val='10', col_name='long_property'), **common}).__dict__, + TableColumnStats( + **{**dict(stat_name='count', stat_val='2', col_name='long_property'), **common}).__dict__, + ] + ] + + result = [] + + while True: + stat = extractor.extract() + + if stat: + result.append(stat) + else: + break + + result_spec = [{x: spec.__dict__[x] for x in 
compare_params if x in spec.__dict__} for spec in result] + + for r in result: + self.assertIsInstance(r, TableColumnStats) + + self.assertListEqual(expected, result_spec) + + def test_extractor_without_technical_data_v6(self) -> None: + self._test_extractor_without_technical_data(self.es_version_v6, self.indices_v6, self.stats_v6) + + def test_extractor_without_technical_data_v7(self) -> None: + self._test_extractor_without_technical_data(self.es_version_v7, self.indices_v7, self.stats_v7) diff --git a/databuilder/tests/unit/extractor/test_es_last_updated_extractor.py b/databuilder/tests/unit/extractor/test_es_last_updated_extractor.py new file mode 100644 index 0000000000..f96086e095 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_es_last_updated_extractor.py @@ -0,0 +1,34 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import Any + +from mock import patch +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.es_last_updated_extractor import EsLastUpdatedExtractor + + +class TestEsLastUpdatedExtractor(unittest.TestCase): + + def setUp(self) -> None: + config_dict = { + 'extractor.es_last_updated.model_class': + 'databuilder.models.es_last_updated.ESLastUpdated', + } + self.conf = ConfigFactory.from_dict(config_dict) + + @patch('time.time') + def test_extraction_with_model_class(self, mock_time: Any) -> None: + """ + Test Extraction using model class + """ + mock_time.return_value = 10000000 + extractor = EsLastUpdatedExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + + result = extractor.extract() + self.assertEqual(result.timestamp, 10000000) diff --git a/databuilder/tests/unit/extractor/test_es_metadata_extractor.py b/databuilder/tests/unit/extractor/test_es_metadata_extractor.py new file mode 100644 index 0000000000..d52517915f --- /dev/null +++ b/databuilder/tests/unit/extractor/test_es_metadata_extractor.py @@ -0,0 +1,317 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import Any + +from elasticsearch import Elasticsearch +from mock import MagicMock +from pyhocon import ConfigFactory, ConfigTree + +from databuilder import Scoped +from databuilder.extractor.es_metadata_extractor import ElasticsearchMetadataExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + + +class TestElasticsearchIndexExtractor(unittest.TestCase): + config_not_sorted = ConfigFactory.from_dict({ + 'extractor.es_metadata.schema': 'schema_name', + 'extractor.es_metadata.cluster': 'cluster_name', + 'extractor.es_metadata.correct_sort_order': False, + 'extractor.es_metadata.client': Elasticsearch()}) + + config_sorted = ConfigFactory.from_dict({ + 'extractor.es_metadata.schema': 'schema_name', + 'extractor.es_metadata.cluster': 'cluster_name', + 'extractor.es_metadata.correct_sort_order': True, + 'extractor.es_metadata.client': Elasticsearch()}) + + es_version_v6 = '6.0.0' + + es_version_v7 = '7.0.0' + + indices_v6 = { + '.technical_index': { + 'mappings': { + 'doc': { + 'properties': { + 'keyword_property': { + 'type': 'keyword' + }, + 'long_property': { + 'type': 'long' + } + } + } + }, + 'aliases': { + 'search_index': {} + }, + 'settings': { + 'number_of_replicas': 1 + } + }, + 'proper_index': { + 'mappings': { + 'doc': { + 'properties': { + 'keyword_property': { + 'type': 'keyword' + }, + 'long_property': { + 'type': 'long' + } + } + } + }, + 'aliases': { + 'search_index': {} + }, + 'settings': { + 'number_of_replicas': 1 + } + } + } + + indices_v7 = { + '.technical_index': { + 'mappings': { + 'properties': { + 'keyword_property': { + 'type': 'keyword' + }, + 'long_property': { + 'type': 'long' + } + } + }, + 'aliases': { + 'search_index': {} + }, + 'settings': { + 'number_of_replicas': 1 + } + }, + 'proper_index': { + 'mappings': { + 'properties': { + 'keyword_property': { + 'type': 'keyword' + }, + 'long_property': { + 'type': 'long' + } + } + }, + 'aliases': { + 'search_index': {} + }, + 'settings': { + 'number_of_replicas': 1 + } + } + } + + indices_hierarchical_v6 = { + 'proper_index': { + 'mappings': { + 'doc': { + 'properties': { + 'keyword': { + 'properties': { + 'inner': { + 'type': 'keyword' + } + } + }, + 'long': { + 'properties': { + 'inner': { + 'type': 'long' + } + } + } + } + } + }, + 'aliases': { + 'search_index': {} + }, + 'settings': { + 'number_of_replicas': 1 + } + } + } + + indices_hierarchical_v7 = { + 'proper_index': { + 'mappings': { + 'properties': { + 'keyword': { + 'properties': { + 'inner': { + 'type': 'keyword' + } + } + }, + 'long': { + 'properties': { + 'inner': { + 'type': 'long' + } + } + } + } + }, + 'aliases': { + 'search_index': {} + }, + 'settings': { + 'number_of_replicas': 1 + } + } + } + + def _get_extractor(self, config: ConfigTree) -> Any: + extractor = ElasticsearchMetadataExtractor() + extractor.init(Scoped.get_scoped_conf(conf=config, scope=extractor.get_scope())) + + return extractor + + def test_extractor_without_technical_data_es_v6(self) -> None: + extractor = self._get_extractor(self.config_not_sorted) + + extractor._get_es_version = lambda: self.es_version_v6 + extractor.es.indices.get = MagicMock(return_value=self.indices_v6) + + expected = TableMetadata('elasticsearch', 'cluster_name', 'schema_name', 'proper_index', + None, [ColumnMetadata('keyword_property', '', 'keyword', 0, []), + ColumnMetadata('long_property', '', 'long', 0, [])], False, []) + + result = [] + + while True: + entry = extractor.extract() + + if entry: + 
result.append(entry) + else: + break + + self.assertEqual(1, len(result)) + self.assertEqual(expected.__repr__(), result[0].__repr__()) + + def test_extractor_without_technical_data_es_v7(self) -> None: + extractor = self._get_extractor(self.config_not_sorted) + + extractor._get_es_version = lambda: self.es_version_v7 + extractor.es.indices.get = MagicMock(return_value=self.indices_v7) + + expected = TableMetadata('elasticsearch', 'cluster_name', 'schema_name', 'proper_index', + None, [ColumnMetadata('keyword_property', '', 'keyword', 0, []), + ColumnMetadata('long_property', '', 'long', 0, [])], False, []) + + result = [] + + while True: + entry = extractor.extract() + + if entry: + result.append(entry) + else: + break + + self.assertEqual(1, len(result)) + self.assertEqual(expected.__repr__(), result[0].__repr__()) + + def test_extractor_hierarchical_es_v6(self) -> None: + extractor = self._get_extractor(self.config_not_sorted) + + extractor._get_es_version = lambda: self.es_version_v6 + extractor.es.indices.get = MagicMock(return_value=self.indices_hierarchical_v6) + + expected = TableMetadata('elasticsearch', 'cluster_name', 'schema_name', 'proper_index', + None, [ColumnMetadata('keyword.inner', '', 'keyword', 0, []), + ColumnMetadata('long.inner', '', 'long', 0, [])], False, []) + + result = [] + + while True: + entry = extractor.extract() + + if entry: + result.append(entry) + else: + break + + self.assertEqual(1, len(result)) + self.assertEqual(expected.__repr__(), result[0].__repr__()) + + def test_extractor_hierarchical_es_v7(self) -> None: + extractor = self._get_extractor(self.config_not_sorted) + + extractor._get_es_version = lambda: self.es_version_v7 + extractor.es.indices.get = MagicMock(return_value=self.indices_hierarchical_v7) + + expected = TableMetadata('elasticsearch', 'cluster_name', 'schema_name', 'proper_index', + None, [ColumnMetadata('keyword.inner', '', 'keyword', 0, []), + ColumnMetadata('long.inner', '', 'long', 0, [])], False, []) + + result = [] + + while True: + entry = extractor.extract() + + if entry: + result.append(entry) + else: + break + + self.assertEqual(1, len(result)) + self.assertEqual(expected.__repr__(), result[0].__repr__()) + + def test_extractor_sorted_es_v6(self) -> None: + extractor = self._get_extractor(self.config_sorted) + + extractor._get_es_version = lambda: self.es_version_v6 + extractor.es.indices.get = MagicMock(return_value=self.indices_v6) + + expected = TableMetadata('elasticsearch', 'cluster_name', 'schema_name', 'proper_index', + None, [ColumnMetadata('keyword_property', '', 'keyword', 0, []), + ColumnMetadata('long_property', '', 'long', 1, [])], False, []) + + result = [] + + while True: + entry = extractor.extract() + + if entry: + result.append(entry) + else: + break + + self.assertEqual(1, len(result)) + self.assertEqual(expected.__repr__(), result[0].__repr__()) + + def test_extractor_sorted_es_v7(self) -> None: + extractor = self._get_extractor(self.config_sorted) + + extractor._get_es_version = lambda: self.es_version_v7 + extractor.es.indices.get = MagicMock(return_value=self.indices_v7) + + expected = TableMetadata('elasticsearch', 'cluster_name', 'schema_name', 'proper_index', + None, [ColumnMetadata('keyword_property', '', 'keyword', 0, []), + ColumnMetadata('long_property', '', 'long', 1, [])], False, []) + + result = [] + + while True: + entry = extractor.extract() + + if entry: + result.append(entry) + else: + break + + self.assertEqual(1, len(result)) + self.assertEqual(expected.__repr__(), result[0].__repr__()) 
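Not part of the diff itself: a minimal, illustrative sketch of how the ElasticsearchMetadataExtractor exercised by the tests above is typically driven outside a unit test. It reuses the exact config keys shown in those tests; the Elasticsearch() client would point at a real cluster, and in a real pipeline the extracted TableMetadata records would be handed to a databuilder task/loader rather than printed.

    from elasticsearch import Elasticsearch
    from pyhocon import ConfigFactory

    from databuilder import Scoped
    from databuilder.extractor.es_metadata_extractor import ElasticsearchMetadataExtractor

    # Same config keys the unit tests above exercise.
    conf = ConfigFactory.from_dict({
        'extractor.es_metadata.schema': 'schema_name',
        'extractor.es_metadata.cluster': 'cluster_name',
        'extractor.es_metadata.correct_sort_order': True,
        'extractor.es_metadata.client': Elasticsearch(),  # in practice, a client for a live cluster
    })

    extractor = ElasticsearchMetadataExtractor()
    extractor.init(Scoped.get_scoped_conf(conf=conf, scope=extractor.get_scope()))

    # Drain the extractor the same way the tests do: each call yields one TableMetadata
    # per index (the tests show dot-prefixed technical indices being skipped) and
    # returns nothing further once all indices have been emitted.
    record = extractor.extract()
    while record:
        print(record)  # stand-in for a loader/publisher step
        record = extractor.extract()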
diff --git a/databuilder/tests/unit/extractor/test_es_watermark_extractor.py b/databuilder/tests/unit/extractor/test_es_watermark_extractor.py new file mode 100644 index 0000000000..84581cce5d --- /dev/null +++ b/databuilder/tests/unit/extractor/test_es_watermark_extractor.py @@ -0,0 +1,202 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from datetime import datetime +from typing import ( + Any, Dict, List, +) + +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.base_extractor import Extractor +from databuilder.extractor.es_watermark_extractor import ElasticsearchWatermarkExtractor +from databuilder.models.watermark import Watermark + + +class TestElasticsearchWatermarkBlizzExtractor(unittest.TestCase): + # Index names + index_with_no_data = 'index_with_no_data' + index_with_data_1 = 'index_with_data_1' + index_with_data_2 = 'index_with_data_2' + + # Meta + indices_meta = { + index_with_no_data: { + 'settings': { + 'index': { + 'creation_date': '1641861298000' + } + } + }, + index_with_data_1: { + 'settings': { + 'index': { + 'creation_date': '1641863003000' + } + } + }, + index_with_data_2: { + 'settings': { + 'index': { + 'creation_date': '1641949455000' + } + } + } + } + + # Watermarks + indices_watermarks = { + index_with_no_data: { + 'aggregations': { + 'min_watermark': { + 'value': None + }, + 'max_watermark': { + 'value': None + } + } + }, + index_with_data_1: { + 'aggregations': { + 'min_watermark': { + 'value': 1641863055000 + }, + 'max_watermark': { + 'value': 1641949455000 + } + } + }, + index_with_data_2: { + 'aggregations': { + 'min_watermark': { + 'value': 1641949455000 + }, + 'max_watermark': { + 'value': 1642126450000 + } + } + } + } + + class MockElasticsearch: + def __init__(self, indices: Dict, indices_watermarks: Dict) -> None: + self.indices = {'*': indices} + self.indices_watermarks = indices_watermarks + + def search(self, index: str, size: int, aggs: Dict) -> Any: + return self.indices_watermarks[index] + + def _get_indices_meta(self, index_names: List[str]) -> Dict: + indices_meta = {} + for index_name in index_names: + indices_meta[index_name] = self.indices_meta[index_name] + + return indices_meta + + def _get_config(self, index_names: List[str]) -> Any: + return ConfigFactory.from_dict({ + 'extractor.es_watermark.schema': 'schema_name', + 'extractor.es_watermark.cluster': 'cluster_name', + 'extractor.es_watermark.time_field': 'time', + 'extractor.es_watermark.client': self.MockElasticsearch( + self._get_indices_meta(index_names), + self.indices_watermarks + )}) + + def _get_extractor(self, index_names: List[str]) -> Any: + extractor = ElasticsearchWatermarkExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self._get_config(index_names), scope=extractor.get_scope())) + + return extractor + + def _extract_and_compare(self, extractor: Extractor, expected: List[Watermark]) -> None: + result = [] + + while True: + entry = extractor.extract() + if entry: + result.append(entry) + else: + break + + self.assertEqual(len(expected), len(result)) + for idx in range(len(expected)): + self.assertEqual(expected[idx].__repr__(), result[idx].__repr__()) + + def test_no_indices(self) -> None: + extractor = self._get_extractor([]) + expected: List[Watermark] = [] + self._extract_and_compare(extractor, expected) + + def test_index_with_no_data(self) -> None: + extractor = self._get_extractor([self.index_with_no_data]) + expected: List[Watermark] = [] + 
self._extract_and_compare(extractor, expected) + + def test_index_with_data(self) -> None: + extractor = self._get_extractor([self.index_with_data_1]) + expected = [ + Watermark( + database='elasticsearch', + cluster='cluster_name', + schema='schema_name', + table_name='index_with_data_1', + create_time=datetime.fromtimestamp(1641863003).strftime('%Y-%m-%d %H:%M:%S'), + part_name=f"time={datetime.fromtimestamp(1641863055).strftime('%Y-%m-%d')}", + part_type='low_watermark' + ), + Watermark( + database='elasticsearch', + cluster='cluster_name', + schema='schema_name', + table_name='index_with_data_1', + create_time=datetime.fromtimestamp(1641863003).strftime('%Y-%m-%d %H:%M:%S'), + part_name=f"time={datetime.fromtimestamp(1641949455).strftime('%Y-%m-%d')}", + part_type='high_watermark' + ) + ] + self._extract_and_compare(extractor, expected) + + def test_indices_with_and_without_data(self) -> None: + extractor = self._get_extractor([self.index_with_no_data, self.index_with_data_1, self.index_with_data_2]) + expected = [ + Watermark( + database='elasticsearch', + cluster='cluster_name', + schema='schema_name', + table_name='index_with_data_1', + create_time=datetime.fromtimestamp(1641863003).strftime('%Y-%m-%d %H:%M:%S'), + part_name=f"time={datetime.fromtimestamp(1641863055).strftime('%Y-%m-%d')}", + part_type='low_watermark' + ), + Watermark( + database='elasticsearch', + cluster='cluster_name', + schema='schema_name', + table_name='index_with_data_1', + create_time=datetime.fromtimestamp(1641863003).strftime('%Y-%m-%d %H:%M:%S'), + part_name=f"time={datetime.fromtimestamp(1641949455).strftime('%Y-%m-%d')}", + part_type='high_watermark' + ), + Watermark( + database='elasticsearch', + cluster='cluster_name', + schema='schema_name', + table_name='index_with_data_2', + create_time=datetime.fromtimestamp(1641949455).strftime('%Y-%m-%d %H:%M:%S'), + part_name=f"time={datetime.fromtimestamp(1641949455).strftime('%Y-%m-%d')}", + part_type='low_watermark' + ), + Watermark( + database='elasticsearch', + cluster='cluster_name', + schema='schema_name', + table_name='index_with_data_2', + create_time=datetime.fromtimestamp(1641949455).strftime('%Y-%m-%d %H:%M:%S'), + part_name=f"time={datetime.fromtimestamp(1642126450).strftime('%Y-%m-%d')}", + part_type='high_watermark' + ) + ] + self._extract_and_compare(extractor, expected) diff --git a/databuilder/tests/unit/extractor/test_eventbridge_extractor.py b/databuilder/tests/unit/extractor/test_eventbridge_extractor.py new file mode 100644 index 0000000000..87fb5cd23c --- /dev/null +++ b/databuilder/tests/unit/extractor/test_eventbridge_extractor.py @@ -0,0 +1,454 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import json +import logging +import unittest +from typing import ( + Any, Dict, List, +) + +from mock import patch +from pyhocon import ConfigFactory + +from databuilder.extractor.eventbridge_extractor import EventBridgeExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +registry_name = "TestAmundsen" + +test_schema_openapi_3 = { + "openapi": "3.0.0", + "info": {"version": "1.0.0", "title": "OrderConfirmed"}, + "paths": {}, + "components": { + "schemas": { + "AWSEvent": { + "type": "object", + "required": [ + "detail-type", + "resources", + "detail", + "id", + "source", + "time", + "region", + "account", + ], + "properties": { + "detail": {"$ref": "#/components/schemas/OrderConfirmed"}, + "account": {"type": "string"}, + "detail-type": {"type": "string"}, + "id": {"type": "string"}, + "region": {"type": "string"}, + "resources": {"type": "array", "items": {"type": "string"}}, + "source": {"type": "string"}, + "time": {"type": "string", "format": "date-time"}, + }, + }, + "OrderConfirmed": { + "type": "object", + "properties": { + "id": {"type": "number", "format": "int64"}, + "status": {"type": "string"}, + "currency": {"type": "string"}, + "customer": {"$ref": "#/components/schemas/Customer"}, + "items": { + "type": "array", + "items": {"$ref": "#/components/schemas/Item"}, + }, + }, + }, + "Customer": { + "type": "object", + "properties": { + "firstName": {"type": "string"}, + "lastName": {"type": "string"}, + "email": {"type": "string"}, + "phone": {}, + }, + "description": "customer description", + }, + "Item": { + "type": "object", + "properties": { + "sku": {"type": "number", "format": "int64"}, + "name": {"type": "string"}, + "price": {"type": "number", "format": "double"}, + "quantity": {"type": "number", "format": "int32"}, + }, + }, + "PrimitiveSchema": {"type": "bool"}, + } + }, +} + +openapi_3_item_type = ( + "struct" +) +openapi_3_customer_type = ( + "struct" +) +openapi_3_order_confirmed_type = ( + f"struct>" +) + +expected_openapi_3_tables = [ + TableMetadata( + "eventbridge", + registry_name, + "OrderConfirmed", + "AWSEvent", + None, + [ + ColumnMetadata("detail", None, openapi_3_order_confirmed_type, 0), + ColumnMetadata("account", None, "string", 1), + ColumnMetadata("detail-type", None, "string", 2), + ColumnMetadata("id", None, "string", 3), + ColumnMetadata("region", None, "string", 4), + ColumnMetadata("resources", None, "array", 5), + ColumnMetadata("source", None, "string", 6), + ColumnMetadata("time", None, "string[date-time]", 7), + ], + False, + ), + TableMetadata( + "eventbridge", + registry_name, + "OrderConfirmed", + "OrderConfirmed", + None, + [ + ColumnMetadata("id", None, "number[int64]", 0), + ColumnMetadata("status", None, "string", 1), + ColumnMetadata("currency", None, "string", 2), + ColumnMetadata( + "customer", "customer description", openapi_3_customer_type, 3, + ), + ColumnMetadata("items", None, f"array<{openapi_3_item_type}>", 4), + ], + False, + ), + TableMetadata( + "eventbridge", + registry_name, + "OrderConfirmed", + "Customer", + None, + [ + ColumnMetadata("firstName", None, "string", 0), + ColumnMetadata("lastName", None, "string", 1), + ColumnMetadata("email", None, "string", 2), + ColumnMetadata("phone", None, "object", 3), + ], + False, + ), + TableMetadata( + "eventbridge", + registry_name, + "OrderConfirmed", + "Item", + None, + [ + ColumnMetadata("sku", None, "number[int64]", 0), + ColumnMetadata("name", None, "string", 1), + ColumnMetadata("price", None, 
"number[double]", 2), + ColumnMetadata("quantity", None, "number[int32]", 3), + ], + False, + ), +] + +test_schema_json_draft_4 = { + "$schema": "http://json-schema.org/draft-04/schema#", + "$id": "http://example.com/example.json", + "type": "object", + "title": "The root schema", + "description": "The root schema comprises the entire JSON document.", + "required": [ + "version", + "id", + "detail-type", + "source", + "account", + "time", + "region", + "resources", + "detail", + ], + "definitions": { + "BookingDone": { + "type": "object", + "properties": {"booking": {"$ref": "#/definitions/Booking"}}, + }, + "Booking": { + "type": "object", + "properties": { + "id": {"type": "string"}, + "status": {"type": "string"}, + "customer": {"$ref": "#/definitions/Customer"}, + }, + "required": ["id", "status", "customer"], + }, + "Customer": { + "type": "object", + "properties": {"id": {"type": "string"}, "name": {"type": "string"}}, + "required": ["id", "name"], + }, + }, + "properties": { + "version": { + "$id": "#/properties/version", + "type": "string", + "description": "version description", + }, + "id": {"$id": "#/properties/id", "type": "string"}, + "detail-type": {"$id": "#/properties/detail-type", "type": "string"}, + "source": {"$id": "#/properties/source", "type": "string"}, + "account": {"$id": "#/properties/account", "type": "string"}, + "time": {"$id": "#/properties/time", "type": "string"}, + "region": {"$id": "#/properties/region", "type": "string"}, + "resources": { + "$id": "#/properties/resources", + "type": "array", + "additionalItems": True, + "items": {"$id": "#/properties/resources/items", "type": "string"}, + }, + "detail": {"$ref": "#/definitions/BookingDone"}, + }, +} + +json_draft_4_customer_type = "struct" +json_draft_4_booking_type = ( + f"struct" +) +json_draft_4_booking_done_type = f"struct" + +expected_json_draft_4_tables = [ + TableMetadata( + "eventbridge", + registry_name, + "The root schema", + "BookingDone", + None, + [ColumnMetadata("booking", None, json_draft_4_booking_type, 0)], + False, + ), + TableMetadata( + "eventbridge", + registry_name, + "The root schema", + "Booking", + None, + [ + ColumnMetadata("id", None, "string", 0), + ColumnMetadata("status", None, "string", 1), + ColumnMetadata("customer", None, json_draft_4_customer_type, 2), + ], + False, + ), + TableMetadata( + "eventbridge", + registry_name, + "The root schema", + "Customer", + None, + [ + ColumnMetadata("id", None, "string", 0), + ColumnMetadata("name", None, "string", 1), + ], + False, + ), + TableMetadata( + "eventbridge", + registry_name, + "The root schema", + "Root", + "The root schema comprises the entire JSON document.", + [ + ColumnMetadata("version", "version description", "string", 0), + ColumnMetadata("id", None, "string", 1), + ColumnMetadata("detail-type", None, "string", 2), + ColumnMetadata("source", None, "string", 3), + ColumnMetadata("account", None, "string", 4), + ColumnMetadata("time", None, "string", 5), + ColumnMetadata("region", None, "string", 6), + ColumnMetadata("resources", None, "array", 7), + ColumnMetadata("detail", None, json_draft_4_booking_done_type, 8), + ], + False, + ), +] + +schema_versions = [ + {"SchemaVersion": "1"}, + {"SchemaVersion": "2"}, + {"SchemaVersion": "3"}, +] + +expected_schema_version = "3" + +property_types: List[Dict[Any, Any]] = [ + {"NoType": ""}, + {"type": "object", "NoProperties": {}}, + { + "type": "object", + "properties": { + "property_1": {"type": "string"}, + "property_2": {"type": "number"}, + }, + }, + { + "type": 
"object", + "properties": { + "property_1": { + "type": "object", + "properties": { + "property_1_1": {"type": "string"}, + "property_1_2": {"type": "number", "format": "int64"}, + }, + }, + "property_2": {"type": "number"}, + }, + }, + {"type": "array", "NoItems": {}}, + {"type": "array", "items": {"type": "string"}}, + { + "type": "array", + "items": { + "type": "object", + "properties": { + "property_1": {"type": "string"}, + "property_2": {"type": "number"}, + }, + }, + }, + {"type": "string"}, + {"type": "string", "format": "date-time"}, +] + +expected_property_types = [ + "object", + "struct", + "struct", + "struct,property_2:number>", + "array", + "array", + "array>", + "string", + "string[date-time]", +] + + +# patch whole class to avoid actually calling for boto3.client during tests +@patch("databuilder.extractor.eventbridge_extractor.boto3.client", lambda x: None) +class TestEventBridgeExtractor(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + self.conf = ConfigFactory.from_dict( + {EventBridgeExtractor.REGISTRY_NAME_KEY: registry_name} + ) + self.maxDiff = None + + def test_extraction_with_empty_query_result(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(EventBridgeExtractor, "_search_schemas") as mock_search: + mock_search.return_value = [] + + extractor = EventBridgeExtractor() + extractor.init(self.conf) + + results = extractor.extract() + self.assertEqual(results, None) + + def test_extraction_no_content(self) -> None: + with patch.object(EventBridgeExtractor, "_search_schemas") as mock_search: + mock_search.return_value = [{"NoContent": {}}] + + extractor = EventBridgeExtractor() + extractor.init(self.conf) + + results = extractor.extract() + self.assertEqual(results, None) + + def test_extraction_unsupported_format(self) -> None: + with patch.object(EventBridgeExtractor, "_search_schemas") as mock_search: + mock_search.return_value = [{"Content": json.dumps({})}] + + extractor = EventBridgeExtractor() + extractor.init(self.conf) + + results = extractor.extract() + self.assertEqual(results, None) + + def test_extraction_with_single_result_openapi_3(self) -> None: + with patch.object(EventBridgeExtractor, "_search_schemas") as mock_search: + mock_search.return_value = [{"Content": json.dumps(test_schema_openapi_3)}] + + extractor = EventBridgeExtractor() + extractor.init(self.conf) + + for expected_table in expected_openapi_3_tables: + self.assertEqual( + expected_table.__repr__(), extractor.extract().__repr__() + ) + self.assertIsNone(extractor.extract()) + + def test_extraction_with_single_result_json_draft_4(self) -> None: + with patch.object(EventBridgeExtractor, "_search_schemas") as mock_search: + mock_search.return_value = [ + {"Content": json.dumps(test_schema_json_draft_4)} + ] + + extractor = EventBridgeExtractor() + extractor.init(self.conf) + + for expected_table in expected_json_draft_4_tables: + self.assertEqual( + expected_table.__repr__(), extractor.extract().__repr__() + ) + self.assertIsNone(extractor.extract()) + + def test_extraction_with_multiple_result(self) -> None: + with patch.object(EventBridgeExtractor, "_search_schemas") as mock_search: + mock_search.return_value = [ + {"Content": json.dumps(test_schema_openapi_3)}, + {"Content": json.dumps(test_schema_json_draft_4)}, + ] + + extractor = EventBridgeExtractor() + extractor.init(self.conf) + + for expected_schema in expected_openapi_3_tables: + self.assertEqual( + expected_schema.__repr__(), 
extractor.extract().__repr__() + ) + + for expected_table in expected_json_draft_4_tables: + self.assertEqual( + expected_table.__repr__(), extractor.extract().__repr__() + ) + + self.assertIsNone(extractor.extract()) + + def test_get_latest_schema_version(self) -> None: + self.assertEqual( + EventBridgeExtractor._get_latest_schema_version(schema_versions), + expected_schema_version, + ) + + def test_get_property_type(self) -> None: + for property_type, expected_property_type in zip( + property_types, expected_property_types + ): + self.assertEqual( + EventBridgeExtractor._get_property_type(property_type), + expected_property_type, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/databuilder/tests/unit/extractor/test_feast_extractor.py b/databuilder/tests/unit/extractor/test_feast_extractor.py new file mode 100644 index 0000000000..3a24d618b4 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_feast_extractor.py @@ -0,0 +1,182 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import os +import pathlib +import re +import unittest +from datetime import datetime + +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.feast_extractor import FeastExtractor +from databuilder.models.description_metadata import DescriptionMetadata +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + + +class TestFeastExtractor(unittest.TestCase): + expected_created_time = datetime.strptime("2020-01-01 00:00:00", "%Y-%m-%d %H:%M:%S") + + def setUp(self) -> None: + repo_path = pathlib.Path(__file__).parent.parent.resolve() / "resources/extractor/feast/fs" + os.system(f"cd {repo_path} && feast apply") + + def test_feature_view_extraction(self) -> None: + self._init_extractor(programmatic_description_enabled=False) + + table = self.extractor.extract() + + expected = TableMetadata( + database="feast", + cluster="local", + schema="fs", + name="driver_hourly_stats", + description=None, + columns=[ + ColumnMetadata( + "driver_id", "Internal identifier of the driver", "INT64", 0 + ), + ColumnMetadata("conv_rate", None, "FLOAT", 1), + ColumnMetadata("acc_rate", None, "FLOAT", 2), + ColumnMetadata("avg_daily_trips", None, "INT64", 3), + ], + ) + + self.assertEqual(expected.__repr__(), table.__repr__()) + + def test_feature_table_extraction_with_description_batch(self) -> None: + self._init_extractor(programmatic_description_enabled=True) + + root_tests_path = pathlib.Path(__file__).parent.parent.resolve() + feature_table_definition = self.extractor.extract() + assert isinstance(feature_table_definition, TableMetadata) + + description = self.extractor.extract() + assert isinstance(description, TableMetadata) + expected = DescriptionMetadata( + TestFeastExtractor._strip_margin( + f"""* Created at **{self.expected_created_time}** + |* Tags: + | * is_pii: **true** + |""" + ), + "feature_view_details", + ) + self.assertEqual(expected.__repr__(), description.description.__repr__()) + + batch_source = self.extractor.extract() + assert isinstance(batch_source, TableMetadata) + expected = DescriptionMetadata( + TestFeastExtractor._strip_margin( + f"""``` + |type: BATCH_FILE + |event_timestamp_column: "event_timestamp" + |created_timestamp_column: "created" + |file_options {"{"} + | file_url: "{root_tests_path}/resources/extractor/feast/fs/data/driver_stats.parquet" + |{"}"} + |```""" + ), + "batch_source", + ) + self.assertEqual(expected.__repr__(), batch_source.description.__repr__()) + + def 
test_feature_table_extraction_with_description_stream(self) -> None: + self._init_extractor(programmatic_description_enabled=True) + root_tests_path = pathlib.Path(__file__).parent.parent.resolve() + + feature_table_definition = self.extractor.extract() + assert isinstance(feature_table_definition, TableMetadata) + + description = self.extractor.extract() + assert isinstance(description, TableMetadata) + expected = DescriptionMetadata( + TestFeastExtractor._strip_margin( + f"""* Created at **{self.expected_created_time}** + |* Tags: + | * is_pii: **true** + |""" + ), + "feature_view_details", + ) + self.assertEqual(expected.__repr__(), description.description.__repr__()) + + batch_source = self.extractor.extract() + assert isinstance(batch_source, TableMetadata) + expected = DescriptionMetadata( + TestFeastExtractor._strip_margin( + f"""``` + |type: BATCH_FILE + |event_timestamp_column: "event_timestamp" + |created_timestamp_column: "created" + |file_options {"{"} + | file_url: "{root_tests_path}/resources/extractor/feast/fs/data/driver_stats.parquet" + |{"}"} + |```""" + ), + "batch_source", + ) + self.assertEqual(expected.__repr__(), batch_source.description.__repr__()) + + stream_source = self.extractor.extract() + assert isinstance(stream_source, TableMetadata) + schema_json = re.sub( + "\n[ \t]*\\|", + "", + """\\\'{\\"type\\": \\"record\\", + |\\"name\\": \\"driver_hourly_stats\\", + |\\"fields\\": [ + | {\\"name\\": \\"conv_rate\\", \\"type\\": \\"float\\"}, + | {\\"name\\": \\"acc_rate\\", \\"type\\": \\"float\\"}, + | {\\"name\\": \\"avg_daily_trips\\", \\"type\\": \\"int\\"}, + | {\\"name\\": \\"datetime\\", \\"type\\": + | {\\"type\\": \\"long\\", \\"logicalType\\": \\"timestamp-micros\\"}}]}\\\'""") + expected = DescriptionMetadata( + TestFeastExtractor._strip_margin( + """``` + |type: STREAM_KAFKA + |event_timestamp_column: "datetime" + |created_timestamp_column: "datetime" + |kafka_options {{ + | bootstrap_servers: "broker1" + | topic: "driver_hourly_stats" + | message_format {{ + | avro_format {{ + | schema_json: "{schema_json}" + | }} + | }} + |}} + |```""").format(schema_json=schema_json), + "stream_source", + ) + print(stream_source.description.__repr__()) + + print(expected.__repr__()) + self.assertEqual(expected.__repr__(), stream_source.description.__repr__()) + + def _init_extractor(self, programmatic_description_enabled: bool = True) -> None: + repository_path = pathlib.Path(__file__).parent.parent.resolve() / "resources/extractor/feast/fs" + conf = { + f"extractor.feast.{FeastExtractor.FEAST_REPOSITORY_PATH}": repository_path, + f"extractor.feast.{FeastExtractor.DESCRIBE_FEATURE_VIEWS}": programmatic_description_enabled, + } + self.extractor = FeastExtractor() + self.extractor.init( + Scoped.get_scoped_conf( + conf=ConfigFactory.from_dict(conf), scope=self.extractor.get_scope() + ) + ) + + @staticmethod + def _strip_margin(text: str) -> str: + return re.sub("\n[ \t]*\\|", "\n", text) + + def tearDown(self) -> None: + root_path = pathlib.Path(__file__).parent.parent.resolve() / "resources/extractor/feast/fs/data" + os.remove(root_path / "online_store.db") + os.remove(root_path / "registry.db") + + +if __name__ == "__main__": + unittest.main() diff --git a/databuilder/tests/unit/extractor/test_generic_extractor.py b/databuilder/tests/unit/extractor/test_generic_extractor.py new file mode 100644 index 0000000000..bb91f3b8c6 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_generic_extractor.py @@ -0,0 +1,47 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.generic_extractor import GenericExtractor + + +class TestGenericExtractor(unittest.TestCase): + + def test_extraction_with_model_class(self) -> None: + """ + Test Extraction using model class + """ + config_dict = { + 'extractor.generic.extraction_items': [{'timestamp': 10000000}], + 'extractor.generic.model_class': 'databuilder.models.es_last_updated.ESLastUpdated', + } + conf = ConfigFactory.from_dict(config_dict) + + extractor = GenericExtractor() + self.conf = ConfigFactory.from_dict(config_dict) + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + + result = extractor.extract() + self.assertEqual(result.timestamp, 10000000) + + def test_extraction_without_model_class(self) -> None: + """ + Test Extraction using model class + """ + config_dict = { + 'extractor.generic.extraction_items': [{'foo': 1}, {'bar': 2}], + } + conf = ConfigFactory.from_dict(config_dict) + + extractor = GenericExtractor() + self.conf = ConfigFactory.from_dict(config_dict) + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + + self.assertEqual(extractor.extract(), {'foo': 1}) + self.assertEqual(extractor.extract(), {'bar': 2}) diff --git a/databuilder/tests/unit/extractor/test_generic_usage_extractor.py b/databuilder/tests/unit/extractor/test_generic_usage_extractor.py new file mode 100644 index 0000000000..aff4968661 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_generic_usage_extractor.py @@ -0,0 +1,99 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest + +from mock import MagicMock, patch +from pyhocon import ConfigFactory + +from databuilder.extractor.generic_usage_extractor import GenericUsageExtractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.models.table_column_usage import ColumnReader, TableColumnUsage + + +class TestGenericUsageExtractor(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + config_dict = { + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + f'extractor.generic_usage.{GenericUsageExtractor.POPULARITY_TABLE_DATABASE}': 'WhateverNameOfYourDb', + f'extractor.generic_usage.{GenericUsageExtractor.POPULARTIY_TABLE_SCHEMA}': 'WhateverNameOfYourSchema', + f'extractor.generic_usage.{GenericUsageExtractor.POPULARITY_TABLE_NAME}': 'WhateverNameOfYourTable', + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_extraction_with_empty_query_result(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = GenericUsageExtractor() + extractor.init(self.conf) + + results = extractor.extract() + self.assertEqual(results, None) + + def test_extraction(self) -> None: + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection: + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + + sql_execute.return_value = [{ + 'database': 'gold', + 'schema': 'scm', + 'name': 'foo', + 'user_email': 'john@example.com', + 'read_count': 1 + }] + + expected = TableColumnUsage( + col_readers=[ + ColumnReader( + database='snowflake', + cluster='gold', + schema='scm', + table='foo', + column='*', + 
user_email='john@example.com', + read_count=1 + ) + ] + ) + + extractor = GenericUsageExtractor() + extractor.init(self.conf) + actual = extractor.extract() + self.assertEqual(expected.__repr__(), actual.__repr__()) + self.assertIsNone(extractor.extract()) + + +class TestGenericUsageExtractorWithWhereClause(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.where_clause_suffix = """ + where user_email != 'wrong_user@email.com' + """ + + config_dict = { + GenericUsageExtractor.WHERE_CLAUSE_SUFFIX_KEY: self.where_clause_suffix, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': + 'TEST_CONNECTION' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = GenericUsageExtractor() + extractor.init(self.conf) + self.assertTrue(self.where_clause_suffix in extractor.sql_stmt) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/test_glue_extractor.py b/databuilder/tests/unit/extractor/test_glue_extractor.py new file mode 100644 index 0000000000..c33e7e57e3 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_glue_extractor.py @@ -0,0 +1,305 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest + +from mock import patch +from pyhocon import ConfigFactory + +from databuilder.extractor.glue_extractor import GlueExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +test_table = { + 'Name': 'test_table', + 'DatabaseName': 'test_schema', + 'Description': 'a table for testing', + 'StorageDescriptor': { + 'Columns': [ + { + 'Name': 'col_id1', + 'Type': 'bigint', + 'Comment': 'description of id1' + }, + { + 'Name': 'col_id2', + 'Type': 'bigint', + 'Comment': 'description of id2' + }, + { + 'Name': 'is_active', + 'Type': 'boolean' + }, + { + 'Name': 'source', + 'Type': 'varchar', + 'Comment': 'description of source' + }, + { + 'Name': 'etl_created_at', + 'Type': 'timestamp', + 'Comment': 'description of etl_created_at' + }, + { + 'Name': 'ds', + 'Type': 'varchar' + } + ] + }, + 'PartitionKeys': [ + { + 'Name': 'partition_key1', + 'Type': 'string', + 'Comment': 'description of partition_key1' + }, + ], + 'TableType': 'EXTERNAL_TABLE', +} + + +# patch whole class to avoid actually calling for boto3.client during tests +@patch('databuilder.extractor.glue_extractor.boto3.client', lambda x: None) +class TestGlueExtractor(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + self.conf = ConfigFactory.from_dict({}) + self.maxDiff = None + + def test_extraction_with_empty_query_result(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(GlueExtractor, '_search_tables'): + extractor = GlueExtractor() + extractor.init(self.conf) + + results = extractor.extract() + self.assertEqual(results, None) + + def test_extraction_with_single_result(self) -> None: + with patch.object(GlueExtractor, '_search_tables') as mock_search: + mock_search.return_value = [ + { + 'Name': 'test_table', + 'DatabaseName': 'test_schema', + 'Description': 'a table for testing', + 'StorageDescriptor': { + 'Columns': [ + { + 'Name': 'col_id1', + 'Type': 'bigint', + 'Comment': 'description of id1' + }, + { + 'Name': 'col_id2', + 'Type': 'bigint', + 'Comment': 'description of id2' + }, + { + 
'Name': 'is_active', + 'Type': 'boolean' + }, + { + 'Name': 'source', + 'Type': 'varchar', + 'Comment': 'description of source' + }, + { + 'Name': 'etl_created_at', + 'Type': 'timestamp', + 'Comment': 'description of etl_created_at' + }, + { + 'Name': 'ds', + 'Type': 'varchar' + } + ] + }, + 'PartitionKeys': [ + { + 'Name': 'partition_key1', + 'Type': 'string', + 'Comment': 'description of partition_key1' + }, + ], + 'TableType': 'EXTERNAL_TABLE', + } + ] + + extractor = GlueExtractor() + extractor.init(self.conf) + actual = extractor.extract() + expected = TableMetadata('glue', 'gold', 'test_schema', 'test_table', 'a table for testing', + [ColumnMetadata('col_id1', 'description of id1', 'bigint', 0), + ColumnMetadata('col_id2', 'description of id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5), + ColumnMetadata('partition_key1', 'description of partition_key1', 'string', 6), + ], False) + self.assertEqual(expected.__repr__(), actual.__repr__()) + self.assertIsNone(extractor.extract()) + + def test_extraction_with_multiple_result(self) -> None: + with patch.object(GlueExtractor, '_search_tables') as mock_search: + mock_search.return_value = [ + test_table, + { + 'Name': 'test_table2', + 'DatabaseName': 'test_schema1', + 'Description': 'test table 2', + 'StorageDescriptor': { + 'Columns': [ + { + 'Name': 'col_name', + 'Type': 'varchar', + 'Comment': 'description of col_name' + }, + { + 'Name': 'col_name2', + 'Type': 'varchar', + 'Comment': 'description of col_name2' + } + ] + }, + 'TableType': 'EXTERNAL_TABLE', + }, + { + 'Name': 'test_table3', + 'DatabaseName': 'test_schema2', + 'StorageDescriptor': { + 'Columns': [ + { + 'Name': 'col_id3', + 'Type': 'varchar', + 'Comment': 'description of col_id3' + }, + { + 'Name': 'col_name3', + 'Type': 'varchar', + 'Comment': 'description of col_name3' + } + ] + }, + 'Parameters': {'comment': 'description of test table 3 from comment'}, + 'TableType': 'EXTERNAL_TABLE', + }, + { + 'Name': 'test_view1', + 'DatabaseName': 'test_schema1', + 'Description': 'test view 1', + 'StorageDescriptor': { + 'Columns': [ + { + 'Name': 'col_id3', + 'Type': 'varchar', + 'Comment': 'description of col_id3' + }, + { + 'Name': 'col_name3', + 'Type': 'varchar', + 'Comment': 'description of col_name3' + } + ] + }, + 'TableType': 'VIRTUAL_VIEW', + }, + ] + + extractor = GlueExtractor() + extractor.init(self.conf) + + expected = TableMetadata('glue', 'gold', 'test_schema', 'test_table', 'a table for testing', + [ColumnMetadata('col_id1', 'description of id1', 'bigint', 0), + ColumnMetadata('col_id2', 'description of id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5), + ColumnMetadata('partition_key1', 'description of partition_key1', 'string', 6), + ], False) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableMetadata('glue', 'gold', 'test_schema1', 'test_table2', 'test table 2', + [ColumnMetadata('col_name', 'description of col_name', 'varchar', 0), + ColumnMetadata('col_name2', 'description of col_name2', 'varchar', 1)], False) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = 
TableMetadata('glue', 'gold', 'test_schema2', 'test_table3', + 'description of test table 3 from comment', + [ColumnMetadata('col_id3', 'description of col_id3', 'varchar', 0), + ColumnMetadata('col_name3', 'description of col_name3', + 'varchar', 1)], False) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableMetadata('glue', 'gold', 'test_schema1', 'test_view1', 'test view 1', + [ColumnMetadata('col_id3', 'description of col_id3', 'varchar', 0), + ColumnMetadata('col_name3', 'description of col_name3', + 'varchar', 1)], True) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + self.assertIsNone(extractor.extract()) + self.assertIsNone(extractor.extract()) + + def test_extraction_with_resource_link_result(self) -> None: + with patch.object(GlueExtractor, '_search_tables') as mock_search: + mock_search.return_value = [ + test_table, + { + "Name": "test_resource_link", + "DatabaseName": "test_schema", + "TargetTable": { + "CatalogId": "111111111111", + "DatabaseName": "test_schema_external", + "Name": "test_table" + }, + "CatalogId": "222222222222" + } + ] + + extractor = GlueExtractor() + extractor.init(self.conf) + actual = extractor.extract() + expected = TableMetadata('glue', 'gold', 'test_schema', 'test_table', 'a table for testing', + [ColumnMetadata('col_id1', 'description of id1', 'bigint', 0), + ColumnMetadata('col_id2', 'description of id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5), + ColumnMetadata('partition_key1', 'description of partition_key1', 'string', 6), + ], False) + self.assertEqual(expected.__repr__(), actual.__repr__()) + self.assertIsNone(extractor.extract()) + + def test_extraction_with_partition_badge(self) -> None: + with patch.object(GlueExtractor, '_search_tables') as mock_search: + mock_search.return_value = [test_table] + + extractor = GlueExtractor() + extractor.init(conf=ConfigFactory.from_dict({ + GlueExtractor.PARTITION_BADGE_LABEL_KEY: "partition_key", + })) + actual = extractor.extract() + expected = TableMetadata('glue', 'gold', 'test_schema', 'test_table', 'a table for testing', + [ColumnMetadata('col_id1', 'description of id1', 'bigint', 0), + ColumnMetadata('col_id2', 'description of id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5), + ColumnMetadata( + 'partition_key1', + 'description of partition_key1', + 'string', + 6, + ["partition_key"], + ), + ], False) + self.assertEqual(expected.__repr__(), actual.__repr__()) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/test_hive_table_last_updated_extractor.py b/databuilder/tests/unit/extractor/test_hive_table_last_updated_extractor.py new file mode 100644 index 0000000000..6ef31b755b --- /dev/null +++ b/databuilder/tests/unit/extractor/test_hive_table_last_updated_extractor.py @@ -0,0 +1,132 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import itertools +import logging +import unittest +from datetime import datetime +from typing import ( + Iterable, Iterator, Optional, TypeVar, +) + +from mock import MagicMock, patch +from pyhocon import ConfigFactory +from pytz import UTC + +from databuilder.extractor.hive_table_last_updated_extractor import HiveTableLastUpdatedExtractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.filesystem.filesystem import FileSystem +from databuilder.filesystem.metadata import FileMetadata +from databuilder.models.table_last_updated import TableLastUpdated + +T = TypeVar('T') + + +def null_iterator(items: Iterable[T]) -> Iterator[Optional[T]]: + """ + Returns an infinite iterator that returns the items from items, + then infinite Nones. Required because Extractor.extract is expected + to return None when it is exhausted, not terminate. + """ + return itertools.chain(iter(items), itertools.repeat(None)) + + +class TestHiveTableLastUpdatedExtractor(unittest.TestCase): + + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + def test_extraction_with_empty_query_result(self) -> None: + config_dict = { + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + f'filesystem.{FileSystem.DASK_FILE_SYSTEM}': MagicMock() + } + conf = ConfigFactory.from_dict(config_dict) + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = HiveTableLastUpdatedExtractor() + extractor.init(conf) + + result = extractor.extract() + self.assertEqual(result, None) + + def test_extraction_with_partition_table_result(self) -> None: + config_dict = { + f'filesystem.{FileSystem.DASK_FILE_SYSTEM}': MagicMock() + } + conf = ConfigFactory.from_dict(config_dict) + + pt_alchemy_extractor_instance = MagicMock() + non_pt_alchemy_extractor_instance = MagicMock() + with patch.object(HiveTableLastUpdatedExtractor, '_get_partitioned_table_sql_alchemy_extractor', + return_value=pt_alchemy_extractor_instance), \ + patch.object(HiveTableLastUpdatedExtractor, '_get_non_partitioned_table_sql_alchemy_extractor', + return_value=non_pt_alchemy_extractor_instance): + pt_alchemy_extractor_instance.extract = MagicMock(side_effect=null_iterator([{ + 'schema': 'foo_schema', + 'table_name': 'table_1', + 'last_updated_time': 1 + }, { + 'schema': 'foo_schema', + 'table_name': 'table_2', + 'last_updated_time': 2 + }])) + + non_pt_alchemy_extractor_instance.extract = MagicMock(return_value=None) + + extractor = HiveTableLastUpdatedExtractor() + extractor.init(conf) + + result = extractor.extract() + expected = TableLastUpdated(schema='foo_schema', table_name='table_1', last_updated_time_epoch=1, + db='hive', cluster='gold') + self.assertEqual(result.__repr__(), expected.__repr__()) + result = extractor.extract() + expected = TableLastUpdated(schema='foo_schema', table_name='table_2', last_updated_time_epoch=2, + db='hive', cluster='gold') + self.assertEqual(result.__repr__(), expected.__repr__()) + + self.assertIsNone(extractor.extract()) + + def test_extraction(self) -> None: + old_datetime = datetime(2018, 8, 14, 4, 12, 3, tzinfo=UTC) + new_datetime = datetime(2018, 11, 14, 4, 12, 3, tzinfo=UTC) + + fs = MagicMock() + fs.ls = MagicMock(return_value=['/foo/bar', '/foo/baz']) + fs.is_file = MagicMock(return_value=True) + fs.info = MagicMock(side_effect=[ + FileMetadata(path='/foo/bar', last_updated=old_datetime, size=15093), + FileMetadata(path='/foo/baz', last_updated=new_datetime, size=15094) + ]) + + 
pt_alchemy_extractor_instance = MagicMock() + non_pt_alchemy_extractor_instance = MagicMock() + + with patch.object(HiveTableLastUpdatedExtractor, '_get_partitioned_table_sql_alchemy_extractor', + return_value=pt_alchemy_extractor_instance), \ + patch.object(HiveTableLastUpdatedExtractor, '_get_non_partitioned_table_sql_alchemy_extractor', + return_value=non_pt_alchemy_extractor_instance), \ + patch.object(HiveTableLastUpdatedExtractor, '_get_filesystem', + return_value=fs): + pt_alchemy_extractor_instance.extract = MagicMock(return_value=None) + + non_pt_alchemy_extractor_instance.extract = MagicMock(side_effect=null_iterator([{ + 'schema': 'foo_schema', + 'table_name': 'table_1', + 'location': '/foo/bar' + }])) + + extractor = HiveTableLastUpdatedExtractor() + extractor.init(ConfigFactory.from_dict({})) + + result = extractor.extract() + expected = TableLastUpdated(schema='foo_schema', table_name='table_1', + last_updated_time_epoch=1542168723, + db='hive', cluster='gold') + self.assertEqual(result.__repr__(), expected.__repr__()) + + self.assertIsNone(extractor.extract()) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/test_hive_table_metadata_extractor.py b/databuilder/tests/unit/extractor/test_hive_table_metadata_extractor.py new file mode 100644 index 0000000000..d2ffbe4227 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_hive_table_metadata_extractor.py @@ -0,0 +1,275 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest +from typing import Any, Dict + +from mock import MagicMock, patch +from pyhocon import ConfigFactory + +from databuilder.extractor.hive_table_metadata_extractor import HiveTableMetadataExtractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + + +class TestHiveTableMetadataExtractor(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + config_dict = { + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_extraction_with_empty_query_result(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'), \ + patch.object(HiveTableMetadataExtractor, '_choose_default_sql_stm', + return_value=HiveTableMetadataExtractor.DEFAULT_SQL_STATEMENT): + extractor = HiveTableMetadataExtractor() + extractor.init(self.conf) + + results = extractor.extract() + self.assertEqual(results, None) + + def test_extraction_with_single_result(self) -> None: + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection, \ + patch.object(HiveTableMetadataExtractor, '_choose_default_sql_stm', + return_value=HiveTableMetadataExtractor.DEFAULT_SQL_STATEMENT): + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + table = {'schema': 'test_schema', + 'name': 'test_table', + 'description': 'a table for testing', + 'is_view': 0} + + sql_execute.return_value = [ + self._union( + {'col_name': 'col_id1', + 'col_type': 'bigint', + 'col_description': 'description of id1', + 'col_sort_order': 0, + 'is_partition_col': 0}, table), + self._union( + {'col_name': 'col_id2', + 'col_type': 'bigint', + 'col_description': 'description of id2', + 'col_sort_order': 1, + 'is_partition_col': 
0}, table), + self._union( + {'col_name': 'is_active', + 'col_type': 'boolean', + 'col_description': None, + 'col_sort_order': 2, + 'is_partition_col': 1}, table), + self._union( + {'col_name': 'source', + 'col_type': 'varchar', + 'col_description': 'description of source', + 'col_sort_order': 3, + 'is_partition_col': 0}, table), + self._union( + {'col_name': 'etl_created_at', + 'col_type': 'timestamp', + 'col_description': 'description of etl_created_at', + 'col_sort_order': 4, + 'is_partition_col': 0}, table), + self._union( + {'col_name': 'ds', + 'col_type': 'varchar', + 'col_description': None, + 'col_sort_order': 5, + 'is_partition_col': 0}, table) + ] + + extractor = HiveTableMetadataExtractor() + extractor.init(self.conf) + actual = extractor.extract() + expected = TableMetadata('hive', 'gold', 'test_schema', 'test_table', 'a table for testing', + [ColumnMetadata('col_id1', 'description of id1', 'bigint', 0), + ColumnMetadata('col_id2', 'description of id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2, ['partition column']), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', + 4), + ColumnMetadata('ds', None, 'varchar', 5)], + is_view=False) + + self.assertEqual(expected.__repr__(), actual.__repr__()) + self.assertIsNone(extractor.extract()) + + def test_extraction_with_multiple_result(self) -> None: + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection, \ + patch.object(HiveTableMetadataExtractor, '_choose_default_sql_stm', + return_value=HiveTableMetadataExtractor.DEFAULT_SQL_STATEMENT): + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + table = {'schema': 'test_schema1', + 'name': 'test_table1', + 'description': 'test table 1', + 'is_view': 0} + + table1 = {'schema': 'test_schema1', + 'name': 'test_table2', + 'description': 'test table 2', + 'is_view': 0} + + table2 = {'schema': 'test_schema2', + 'name': 'test_table3', + 'description': 'test table 3', + 'is_view': 0} + + sql_execute.return_value = [ + self._union( + {'col_name': 'col_id1', + 'col_type': 'bigint', + 'col_description': 'description of col_id1', + 'col_sort_order': 0, + 'is_partition_col': 1}, table), + self._union( + {'col_name': 'col_id2', + 'col_type': 'bigint', + 'col_description': 'description of col_id2', + 'col_sort_order': 1, + 'is_partition_col': 0}, table), + self._union( + {'col_name': 'is_active', + 'col_type': 'boolean', + 'col_description': None, + 'col_sort_order': 2, + 'is_partition_col': 0}, table), + self._union( + {'col_name': 'source', + 'col_type': 'varchar', + 'col_description': 'description of source', + 'col_sort_order': 3, + 'is_partition_col': 0}, table), + self._union( + {'col_name': 'etl_created_at', + 'col_type': 'timestamp', + 'col_description': 'description of etl_created_at', + 'col_sort_order': 4, + 'is_partition_col': 0}, table), + self._union( + {'col_name': 'ds', + 'col_type': 'varchar', + 'col_description': None, + 'col_sort_order': 5, + 'is_partition_col': 0}, table), + self._union( + {'col_name': 'col_name', + 'col_type': 'varchar', + 'col_description': 'description of col_name', + 'col_sort_order': 0, + 'is_partition_col': 0}, table1), + self._union( + {'col_name': 'col_name2', + 'col_type': 'varchar', + 'col_description': 'description of col_name2', + 'col_sort_order': 1, + 'is_partition_col': 0}, table1), + self._union( + {'col_name': 'col_id3', + 
'col_type': 'varchar', + 'col_description': 'description of col_id3', + 'col_sort_order': 0, + 'is_partition_col': 0}, table2), + self._union( + {'col_name': 'col_name3', + 'col_type': 'varchar', + 'col_description': 'description of col_name3', + 'col_sort_order': 1, + 'is_partition_col': 0}, table2) + ] + + extractor = HiveTableMetadataExtractor() + extractor.init(self.conf) + + expected = TableMetadata('hive', 'gold', 'test_schema1', 'test_table1', 'test table 1', + [ColumnMetadata('col_id1', 'description of col_id1', 'bigint', 0, + ['partition column']), + ColumnMetadata('col_id2', 'description of col_id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', + 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5)], + is_view=False) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableMetadata('hive', 'gold', 'test_schema1', 'test_table2', 'test table 2', + [ColumnMetadata('col_name', 'description of col_name', 'varchar', 0), + ColumnMetadata('col_name2', 'description of col_name2', 'varchar', 1)], + is_view=False) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableMetadata('hive', 'gold', 'test_schema2', 'test_table3', 'test table 3', + [ColumnMetadata('col_id3', 'description of col_id3', 'varchar', 0), + ColumnMetadata('col_name3', 'description of col_name3', + 'varchar', 1)], + is_view=False) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + self.assertIsNone(extractor.extract()) + self.assertIsNone(extractor.extract()) + + def _union(self, + target: Dict[Any, Any], + extra: Dict[Any, Any]) -> Dict[Any, Any]: + target.update(extra) + return target + + +class TestHiveTableMetadataExtractorWithWhereClause(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.where_clause_suffix = """ + AND d.NAME IN ('test_schema1', 'test_schema2') + AND t.TBL_NAME NOT REGEXP '^[0-9]+'""" + + config_dict = { + HiveTableMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY: self.where_clause_suffix, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'), \ + patch.object(HiveTableMetadataExtractor, '_choose_default_sql_stm', + return_value=HiveTableMetadataExtractor.DEFAULT_SQL_STATEMENT): + extractor = HiveTableMetadataExtractor() + extractor.init(self.conf) + self.assertTrue(self.where_clause_suffix in extractor.sql_stmt) + + def test_hive_sql_statement_with_custom_sql(self) -> None: + """ + Test Extraction by providing a custom sql + :return: + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'), \ + patch.object(HiveTableMetadataExtractor, '_choose_default_sql_stm', + return_value=HiveTableMetadataExtractor.DEFAULT_SQL_STATEMENT): + config_dict = { + HiveTableMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY: self.where_clause_suffix, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + HiveTableMetadataExtractor.EXTRACT_SQL: + 'select sth for test {where_clause_suffix}' + } + conf = ConfigFactory.from_dict(config_dict) + extractor = HiveTableMetadataExtractor() + extractor.init(conf) + self.assertTrue('select sth for test' in extractor.sql_stmt) + 
+ +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/test_kafka_schema_registry_extractor.py b/databuilder/tests/unit/extractor/test_kafka_schema_registry_extractor.py new file mode 100644 index 0000000000..c730909597 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_kafka_schema_registry_extractor.py @@ -0,0 +1,225 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest +from typing import List +from unittest.mock import MagicMock + +from mock import patch +from pyhocon import ConfigFactory +from schema_registry.client.schema import AvroSchema +from schema_registry.client.utils import SchemaVersion + +from databuilder.extractor.kafka_schema_registry_extractor import KafkaSchemaRegistryExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +INPUT_SCHEMAS: List[SchemaVersion] = [ + SchemaVersion( + subject="subject1", + schema_id=1, + schema=AvroSchema({ + 'type': 'record', + 'namespace': 'com.kubertenes', + 'name': 'AvroDeployment', + 'fields': [ + {'name': 'image', 'type': 'string'}, + {'name': 'replicas', 'type': 'int'}, + {'name': 'port', 'type': 'int'} + ] + }), + version=1 + ), + SchemaVersion( + subject="subject2", + schema_id=2, + schema=AvroSchema({ + 'namespace': 'my.com.ns', + 'name': 'myrecord', + 'type': 'record', + 'fields': [ + {'name': 'uid', 'type': 'int'}, + {'name': 'somefield', 'type': 'string'}, + {'name': 'options', + 'type': + {'type': 'array', + 'items': + {'type': 'record', + 'name': 'lvl2_record', + 'fields': + [ + {'name': 'item1_lvl2', 'type': 'string'}, + {'name': 'item2_lvl2', + 'type': + { + 'type': 'array', + 'items': + { + 'type': 'record', + 'name': 'lvl3_record', + 'fields': [ + {'name': 'item1_lvl3', + 'type': 'string'}, + {'name': 'item2_lvl3', + 'type': 'string'} + ] + } + } + } + ] + } + } + } + ] + }), + version=1 + ), + SchemaVersion( + subject="subject3", + schema_id=3, + schema=AvroSchema({ + 'type': 'record', + 'name': 'milanoRecord', + 'namespace': 'com.landoop.telecom.telecomitalia.grid', + 'doc': ('Schema for Grid for Telecommunications ' + 'Data from Telecom Italia.'), + 'fields': [ + {'name': 'SquareId', + 'type': 'int', + 'doc': (' The id of the square that ' + 'is part of the Milano GRID')}, + {'name': 'Polygon', + 'type': + {'type': 'array', + 'items': { + 'type': 'record', + 'name': 'coordinates', + 'fields': [ + {'name': 'longitude', 'type': 'double'}, + {'name': 'latitude', 'type': 'double'}]}}}] + }), + version=5 + ) +] +EXPECTED_TABLES: List[TableMetadata] = [ + TableMetadata( + "kafka_schema_registry", + "com.kubertenes", + "subject1", + "AvroDeployment", + None, + [ + ColumnMetadata("image", None, "string", 0), + ColumnMetadata("replicas", None, "int", 1), + ColumnMetadata("port", None, "int", 2), + ], + False, + ), + TableMetadata( + "kafka_schema_registry", + "my.com.ns", + "subject2", + "myrecord", + None, + [ + ColumnMetadata("uid", None, "int", 0), + ColumnMetadata("somefield", None, "string", 1), + ColumnMetadata( + "options", + None, + ("array>>>"), + 2, + ), + ], + False, + ), + TableMetadata( + "kafka_schema_registry", + "com.landoop.telecom.telecomitalia.grid", + "subject3", + "milanoRecord", + "Schema for Grid for Telecommunications Data from Telecom Italia.", + [ + ColumnMetadata( + "SquareId", + " The id of the square that is part of the Milano GRID", + "int", + 0, + ), + ColumnMetadata( + "Polygon", + None, + "array>", + 1, + ), + ], + False, + ), +] + + +class 
TestKafkaSchemaRegistryExtractor(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + self.conf = ConfigFactory.from_dict( + { + KafkaSchemaRegistryExtractor.REGISTRY_URL_KEY: ( + "http://example.com" + ) + } + ) + self.maxDiff = None + + @patch( + ("databuilder.extractor.kafka_schema_registry_extractor." + "SchemaRegistryClient.get_subjects") + ) + def test_extraction_with_empty_query_result( + self, all_subjects: MagicMock + ) -> None: + """ + Test extraction with empty result from query + """ + all_subjects.return_value = [] + + extractor = KafkaSchemaRegistryExtractor() + extractor.init(self.conf) + + results = extractor.extract() + self.assertEqual(results, None) + + @patch( + ("databuilder.extractor.kafka_schema_registry_extractor." + "SchemaRegistryClient.get_schema") + ) + @patch( + ("databuilder.extractor.kafka_schema_registry_extractor." + "SchemaRegistryClient.get_subjects") + ) + def test_extraction_successfully( + self, + all_subjects: MagicMock, + schema: MagicMock, + ) -> None: + """ + Test extraction with the given schemas finishes successfully + """ + all_subjects.return_value = [subj.subject for subj in INPUT_SCHEMAS] + schema.side_effect = INPUT_SCHEMAS + + extractor = KafkaSchemaRegistryExtractor() + extractor.init(self.conf) + + for expected_table in EXPECTED_TABLES: + self.assertEqual( + expected_table.__repr__(), extractor.extract().__repr__() + ) + + self.assertIsNone(extractor.extract()) + + +if __name__ == "__main__": + unittest.main() diff --git a/databuilder/tests/unit/extractor/test_kafka_source_extractor.py b/databuilder/tests/unit/extractor/test_kafka_source_extractor.py new file mode 100644 index 0000000000..f9ec3ddd25 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_kafka_source_extractor.py @@ -0,0 +1,52 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest + +from mock import MagicMock, patch +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.kafka_source_extractor import KafkaSourceExtractor + + +class TestKafkaSourceExtractor(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + config_dict = { + 'extractor.kafka_source.consumer_config': {'"group.id"': 'consumer-group', '"enable.auto.commit"': False}, + f'extractor.kafka_source.{KafkaSourceExtractor.RAW_VALUE_TRANSFORMER}': + 'databuilder.transformer.base_transformer.NoopTransformer', + f'extractor.kafka_source.{KafkaSourceExtractor.TOPIC_NAME_LIST}': ['test-topic'], + f'extractor.kafka_source.{KafkaSourceExtractor.CONSUMER_TOTAL_TIMEOUT_SEC}': 1, + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_consume_success(self) -> None: + kafka_extractor = KafkaSourceExtractor() + kafka_extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=kafka_extractor.get_scope())) + + with patch.object(kafka_extractor, 'consumer') as mock_consumer: + mock_poll = MagicMock() + mock_poll.error.return_value = False + # only return once + mock_poll.value.side_effect = ['msg'] + mock_consumer.poll.return_value = mock_poll + + records = kafka_extractor.consume() + self.assertEqual(len(records), 1) + + def test_consume_fail(self) -> None: + kafka_extractor = KafkaSourceExtractor() + kafka_extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=kafka_extractor.get_scope())) + + with patch.object(kafka_extractor, 'consumer') as mock_consumer: + mock_poll = MagicMock() + mock_poll.error.return_value = True + mock_consumer.poll.return_value = mock_poll + + records = kafka_extractor.consume() + self.assertEqual(len(records), 0) diff --git a/databuilder/tests/unit/extractor/test_mssql_metadata_extractor.py b/databuilder/tests/unit/extractor/test_mssql_metadata_extractor.py new file mode 100644 index 0000000000..2b18199f12 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_mssql_metadata_extractor.py @@ -0,0 +1,312 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest +from typing import Any, Dict + +from mock import MagicMock, patch +from pyhocon import ConfigFactory + +from databuilder.extractor.mssql_metadata_extractor import MSSQLMetadataExtractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + + +class TestMSSQLMetadataExtractor(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + config_dict = { + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + f'extractor.mssql_metadata.{MSSQLMetadataExtractor.CLUSTER_KEY}': 'MY_CLUSTER', + f'extractor.mssql_metadata.{MSSQLMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME}': False, + f'extractor.mssql_metadata.{MSSQLMetadataExtractor.DATABASE_KEY}': 'mssql', + f'extractor.mssql_metadata.{MSSQLMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY}': '' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_extraction_with_empty_query_result(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = MSSQLMetadataExtractor() + extractor.init(self.conf) + + results = extractor.extract() + self.assertEqual(results, None) + + def test_extraction_with_single_result(self) -> None: + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection: + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + table = {'schema_name': 'test_schema', + 'name': 'test_table', + 'description': 'a table for testing', + 'cluster': + self.conf[f'extractor.mssql_metadata.{MSSQLMetadataExtractor.CLUSTER_KEY}'] + } + + sql_execute.return_value = [ + self._union( + {'col_name': 'col_id1', + 'col_type': 'bigint', + 'col_description': 'description of id1', + 'col_sort_order': 0}, table), + self._union( + {'col_name': 'col_id2', + 'col_type': 'bigint', + 'col_description': 'description of id2', + 'col_sort_order': 1}, table), + self._union( + {'col_name': 'is_active', + 'col_type': 'boolean', + 'col_description': None, + 'col_sort_order': 2}, table), + self._union( + {'col_name': 'source', + 'col_type': 'varchar', + 'col_description': 'description of source', + 'col_sort_order': 3}, table), + self._union( + {'col_name': 'etl_created_at', + 'col_type': 'timestamp', + 'col_description': 'description of etl_created_at', + 'col_sort_order': 4}, table), + self._union( + {'col_name': 'ds', + 'col_type': 'varchar', + 'col_description': None, + 'col_sort_order': 5}, table) + ] + + extractor = MSSQLMetadataExtractor() + extractor.init(self.conf) + actual = extractor.extract() + expected = TableMetadata('mssql', 'MY_CLUSTER', 'test_schema', 'test_table', 'a table for testing', + [ColumnMetadata('col_id1', 'description of id1', 'bigint', 0), + ColumnMetadata('col_id2', 'description of id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5)], tags='test_schema') + + self.assertEqual(expected.__repr__(), actual.__repr__()) + self.assertIsNone(extractor.extract()) + + def test_extraction_with_multiple_result(self) -> None: + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection: + connection = 
MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + table = {'schema_name': 'test_schema1', + 'name': 'test_table1', + 'description': 'test table 1', + 'cluster': + self.conf[f'extractor.mssql_metadata.{MSSQLMetadataExtractor.CLUSTER_KEY}'] + } + + table1 = {'schema_name': 'test_schema1', + 'name': 'test_table2', + 'description': 'test table 2', + 'cluster': + self.conf[f'extractor.mssql_metadata.{MSSQLMetadataExtractor.CLUSTER_KEY}'] + } + + table2 = {'schema_name': 'test_schema2', + 'name': 'test_table3', + 'description': 'test table 3', + 'cluster': + self.conf[f'extractor.mssql_metadata.{MSSQLMetadataExtractor.CLUSTER_KEY}'] + } + + sql_execute.return_value = [ + self._union( + {'col_name': 'col_id1', + 'col_type': 'bigint', + 'col_description': 'description of col_id1', + 'col_sort_order': 0}, table), + self._union( + {'col_name': 'col_id2', + 'col_type': 'bigint', + 'col_description': 'description of col_id2', + 'col_sort_order': 1}, table), + self._union( + {'col_name': 'is_active', + 'col_type': 'boolean', + 'col_description': None, + 'col_sort_order': 2}, table), + self._union( + {'col_name': 'source', + 'col_type': 'varchar', + 'col_description': 'description of source', + 'col_sort_order': 3}, table), + self._union( + {'col_name': 'etl_created_at', + 'col_type': 'timestamp', + 'col_description': 'description of etl_created_at', + 'col_sort_order': 4}, table), + self._union( + {'col_name': 'ds', + 'col_type': 'varchar', + 'col_description': None, + 'col_sort_order': 5}, table), + self._union( + {'col_name': 'col_name', + 'col_type': 'varchar', + 'col_description': 'description of col_name', + 'col_sort_order': 0}, table1), + self._union( + {'col_name': 'col_name2', + 'col_type': 'varchar', + 'col_description': 'description of col_name2', + 'col_sort_order': 1}, table1), + self._union( + {'col_name': 'col_id3', + 'col_type': 'varchar', + 'col_description': 'description of col_id3', + 'col_sort_order': 0}, table2), + self._union( + {'col_name': 'col_name3', + 'col_type': 'varchar', + 'col_description': 'description of col_name3', + 'col_sort_order': 1}, table2) + ] + + extractor = MSSQLMetadataExtractor() + extractor.init(self.conf) + + expected = TableMetadata('mssql', + self.conf[f'extractor.mssql_metadata.{MSSQLMetadataExtractor.CLUSTER_KEY}'], + 'test_schema1', 'test_table1', 'test table 1', + [ColumnMetadata('col_id1', 'description of col_id1', 'bigint', 0), + ColumnMetadata('col_id2', 'description of col_id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5)], tags='test_schema1') + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableMetadata('mssql', + self.conf[f'extractor.mssql_metadata.{MSSQLMetadataExtractor.CLUSTER_KEY}'], + 'test_schema1', 'test_table2', 'test table 2', + [ColumnMetadata('col_name', 'description of col_name', 'varchar', 0), + ColumnMetadata('col_name2', 'description of col_name2', 'varchar', 1)], + tags='test_schema1') + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableMetadata('mssql', + self.conf[f'extractor.mssql_metadata.{MSSQLMetadataExtractor.CLUSTER_KEY}'], + 'test_schema2', 'test_table3', 'test table 3', + [ColumnMetadata('col_id3', 'description of col_id3', 'varchar', 0), + 
ColumnMetadata('col_name3', 'description of col_name3', + 'varchar', 1)], tags='test_schema2') + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + self.assertIsNone(extractor.extract()) + self.assertIsNone(extractor.extract()) + + def _union(self, + target: Dict[Any, Any], + extra: Dict[Any, Any]) -> Dict[Any, Any]: + target.update(extra) + return target + + +class TestMSSQLMetadataExtractorWithWhereClause(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.where_clause_suffix = """ + ('dbo', 'sys') + """ + + config_dict = { + MSSQLMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY: self.where_clause_suffix, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = MSSQLMetadataExtractor() + extractor.init(self.conf) + self.assertTrue(self.where_clause_suffix in extractor.sql_stmt) + + +class TestMSSQLMetadataExtractorClusterKeyNoTableCatalog(unittest.TestCase): + # test when USE_CATALOG_AS_CLUSTER_NAME is false and CLUSTER_KEY is specified + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.cluster_key = "not_master" + + config_dict = { + MSSQLMetadataExtractor.CLUSTER_KEY: self.cluster_key, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + MSSQLMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME: False + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = MSSQLMetadataExtractor() + extractor.init(self.conf) + self.assertTrue(self.cluster_key in extractor.sql_stmt) + + +class TestMSSQLMetadataExtractorNoClusterKeyNoTableCatalog(unittest.TestCase): + # test when USE_CATALOG_AS_CLUSTER_NAME is false and CLUSTER_KEY is NOT specified + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + config_dict = { + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + MSSQLMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME: False + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = MSSQLMetadataExtractor() + extractor.init(self.conf) + self.assertTrue(MSSQLMetadataExtractor.DEFAULT_CLUSTER_NAME in extractor.sql_stmt) + + +class TestMSSQLMetadataExtractorTableCatalogEnabled(unittest.TestCase): + # test when USE_CATALOG_AS_CLUSTER_NAME is true (CLUSTER_KEY should be ignored) + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.cluster_key = "not_master" + + config_dict = { + MSSQLMetadataExtractor.CLUSTER_KEY: self.cluster_key, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + MSSQLMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME: True + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = MSSQLMetadataExtractor() + extractor.init(self.conf) + self.assertTrue('DB_NAME()' in extractor.sql_stmt) + self.assertFalse(self.cluster_key in 
extractor.sql_stmt) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/test_mysql_search_data_extractor.py b/databuilder/tests/unit/extractor/test_mysql_search_data_extractor.py new file mode 100644 index 0000000000..16c5b3d601 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_mysql_search_data_extractor.py @@ -0,0 +1,257 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import Any +from unittest.mock import MagicMock, patch + +from amundsen_rds.models.badge import Badge +from amundsen_rds.models.cluster import Cluster +from amundsen_rds.models.column import ColumnDescription, TableColumn +from amundsen_rds.models.dashboard import ( + Dashboard, DashboardChart, DashboardCluster, DashboardDescription, DashboardExecution, DashboardGroup, + DashboardGroupDescription, DashboardQuery, DashboardUsage, +) +from amundsen_rds.models.database import Database +from amundsen_rds.models.schema import Schema, SchemaDescription +from amundsen_rds.models.table import ( + Table, TableDescription, TableProgrammaticDescription, TableTimestamp, TableUsage, +) +from amundsen_rds.models.tag import Tag +from amundsen_rds.models.user import User +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor import mysql_search_data_extractor +from databuilder.extractor.mysql_search_data_extractor import MySQLSearchDataExtractor +from databuilder.models.dashboard_elasticsearch_document import DashboardESDocument +from databuilder.models.table_elasticsearch_document import TableESDocument +from databuilder.models.user_elasticsearch_document import UserESDocument + + +class MyTestCase(unittest.TestCase): + def setUp(self) -> None: + self.maxDiff = None + + @patch.object(mysql_search_data_extractor, '_table_search_query') + @patch.object(mysql_search_data_extractor, 'sessionmaker') + @patch.object(mysql_search_data_extractor, 'create_engine') + def test_table_search(self, + mock_create_engine: Any, + mock_session_maker: Any, + mock_table_search_query: Any) -> None: + database = Database(rk='test_database_key', name='test_database') + cluster = Cluster(rk='test_cluster_key', name='test_cluster') + cluster.database = database + + schema = Schema(rk='test_schema_key', name='test_schema') + schema.description = SchemaDescription(rk='test_schema_description_key', description='test_schema_description') + schema.cluster = cluster + + table = Table(rk='test_table_key', name='test_table') + table.schema = schema + + table.description = TableDescription(rk='test_table_description_key', description='test_table_description') + table.programmatic_descriptions = [TableProgrammaticDescription(rk='test_table_prog_description_key', + description='test_table_prog_description')] + + table.timestamp = TableTimestamp(rk='test_table_timestamp_key', last_updated_timestamp=123456789) + + column1 = TableColumn(rk='test_col1_key', name='test_col1') + column2 = TableColumn(rk='test_col2_key', name='test_col2') + column3 = TableColumn(rk='test_col3_key', name='test_col3') + column1.description = ColumnDescription(rk='test_col1_description_key', + description='test_col1_description') + column2.description = ColumnDescription(rk='test_col2_description_key', + description='test_col2_description') + table.columns = [column1, column2, column3] + + usage1 = TableUsage(user_rk='test_user1_key', table_rk='test_table_key', read_count=5) + usage2 = TableUsage(user_rk='test_user2_key', 
table_rk='test_table_key', read_count=10) + table.usage = [usage1, usage2] + + tags = [Tag(rk='test_tag', tag_type='default')] + table.tags = tags + + badges = [Badge(rk='test_badge')] + table.badges = badges + + tables = [table] + + expected_dict = dict(database='test_database', + cluster='test_cluster', + schema='test_schema', + name='test_table', + display_name='test_schema.test_table', + key='test_table_key', + description='test_table_description', + last_updated_timestamp=123456789, + column_names=['test_col1', 'test_col2', 'test_col3'], + column_descriptions=['test_col1_description', 'test_col2_description', ''], + total_usage=15, + unique_usage=2, + tags=['test_tag'], + badges=['test_badge'], + schema_description='test_schema_description', + programmatic_descriptions=['test_table_prog_description']) + + config_dict = { + f'extractor.mysql_search_data.{MySQLSearchDataExtractor.CONN_STRING}': 'test_conn_string', + f'extractor.mysql_search_data.{MySQLSearchDataExtractor.ENTITY_TYPE}': 'table', + f'extractor.mysql_search_data.{MySQLSearchDataExtractor.JOB_PUBLISH_TAG}': 'test_tag', + f'extractor.mysql_search_data.{MySQLSearchDataExtractor.MODEL_CLASS}': + 'databuilder.models.table_elasticsearch_document.TableESDocument' + } + self.conf = ConfigFactory.from_dict(config_dict) + + extractor = MySQLSearchDataExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + + mock_table_search_query.side_effect = [tables, None] + + actual_obj = extractor.extract() + + self.assertIsInstance(actual_obj, TableESDocument) + self.assertDictEqual(vars(actual_obj), expected_dict) + + @patch.object(mysql_search_data_extractor, '_user_search_query') + @patch.object(mysql_search_data_extractor, 'sessionmaker') + @patch.object(mysql_search_data_extractor, 'create_engine') + def test_user_search(self, + mock_create_engine: Any, + mock_session_maker: Any, + mock_user_search_query: Any) -> None: + user = User(rk='test_user_key', + email='test_user@email.com', + first_name='test_first_name', + last_name='test_last_name', + full_name='test_full_name', + github_username='test_github_username', + team_name='test_team_name', + employee_type='test_employee_type', + slack_id='test_slack_id', + role_name='test_role_name', + is_active=True) + manager = User(rk='test_manager_key', email='test_manager@email.com') + user.manager = manager + + expected_dict = dict(email='test_user@email.com', + first_name='test_first_name', + last_name='test_last_name', + full_name='test_full_name', + github_username='test_github_username', + team_name='test_team_name', + employee_type='test_employee_type', + manager_email='test_manager@email.com', + slack_id='test_slack_id', + role_name='test_role_name', + is_active=True, + total_read=30, + total_own=2, + total_follow=2) + + config_dict = { + f'extractor.mysql_search_data.{MySQLSearchDataExtractor.CONN_STRING}': 'test_conn_string', + f'extractor.mysql_search_data.{MySQLSearchDataExtractor.ENTITY_TYPE}': 'user', + f'extractor.mysql_search_data.{MySQLSearchDataExtractor.JOB_PUBLISH_TAG}': 'test_tag', + f'extractor.mysql_search_data.{MySQLSearchDataExtractor.MODEL_CLASS}': + 'databuilder.models.user_elasticsearch_document.UserESDocument' + } + self.conf = ConfigFactory.from_dict(config_dict) + + extractor = MySQLSearchDataExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + + query_results = [MagicMock(User=user, + table_read_count=10, + dashboard_read_count=20, + table_own_count=1, + 
dashboard_own_count=1, + table_follow_count=1, + dashboard_follow_count=1)] + mock_user_search_query.side_effect = [query_results, None] + + actual_obj = extractor.extract() + + self.assertIsInstance(actual_obj, UserESDocument) + self.assertDictEqual(vars(actual_obj), expected_dict) + + @patch.object(mysql_search_data_extractor, '_dashboard_search_query') + @patch.object(mysql_search_data_extractor, 'sessionmaker') + @patch.object(mysql_search_data_extractor, 'create_engine') + def test_dashboard_search(self, + mock_create_engine: Any, + mock_session_maker: Any, + mock_dashboard_search_query: Any) -> None: + dashboard = Dashboard(rk='test_dashboard//key', name='test_dashboard', dashboard_url='test://dashboard_url') + dashboard.description = DashboardDescription(rk='test_dashboard_description_key', + description='test_dashboard_description') + + group = DashboardGroup(rk='test_group_key', name='test_group', dashboard_group_url='test://group_url') + group.description = DashboardGroupDescription(rk='test_group_description_key', + description='test_group_description') + dashboard.group = group + + cluster = DashboardCluster(rk='test_cluster_key', name='test_cluster') + group.cluster = cluster + + last_exec = DashboardExecution(rk='test_dashboard_exec_key/_last_successful_execution', timestamp=123456789) + dashboard.execution = [last_exec] + + usage1 = DashboardUsage(user_rk='test_user1_key', dashboard_rk='test_dashboard_key', read_count=10) + usage2 = DashboardUsage(user_rk='test_user2_key', dashboard_rk='test_dashboard_key', read_count=5) + dashboard.usage = [usage1, usage2] + + query = DashboardQuery(rk='test_query_key', name='test_query') + query.charts = [DashboardChart(rk='test_chart_key', name='test_chart')] + dashboard.queries = [query] + + tags = [Tag(rk='test_tag', tag_type='default')] + dashboard.tags = tags + + badges = [Badge(rk='test_badge')] + dashboard.badges = badges + + dashboards = [dashboard] + + expected_dict = dict(group_name='test_group', + name='test_dashboard', + description='test_dashboard_description', + product='test', + cluster='test_cluster', + group_description='test_group_description', + query_names=['test_query'], + chart_names=['test_chart'], + group_url='test://group_url', + url='test://dashboard_url', + uri='test_dashboard//key', + last_successful_run_timestamp=123456789, + total_usage=15, + tags=['test_tag'], + badges=['test_badge']) + + config_dict = { + f'extractor.mysql_search_data.{MySQLSearchDataExtractor.CONN_STRING}': 'test_conn_string', + f'extractor.mysql_search_data.{MySQLSearchDataExtractor.ENTITY_TYPE}': 'dashboard', + f'extractor.mysql_search_data.{MySQLSearchDataExtractor.JOB_PUBLISH_TAG}': 'test_tag', + f'extractor.mysql_search_data.{MySQLSearchDataExtractor.MODEL_CLASS}': + 'databuilder.models.dashboard_elasticsearch_document.DashboardESDocument' + } + self.conf = ConfigFactory.from_dict(config_dict) + + extractor = MySQLSearchDataExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + + mock_dashboard_search_query.side_effect = [dashboards, None] + + actual_obj = extractor.extract() + + self.assertIsInstance(actual_obj, DashboardESDocument) + self.assertDictEqual(vars(actual_obj), expected_dict) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/test_neo4j_extractor.py b/databuilder/tests/unit/extractor/test_neo4j_extractor.py new file mode 100644 index 0000000000..b46371dd6d --- /dev/null +++ b/databuilder/tests/unit/extractor/test_neo4j_extractor.py 
@@ -0,0 +1,125 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import Any + +from mock import patch +from neo4j import GraphDatabase +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.neo4j_extractor import Neo4jExtractor +from databuilder.models.table_elasticsearch_document import TableESDocument + + +class TestNeo4jExtractor(unittest.TestCase): + + def setUp(self) -> None: + config_dict = { + f'extractor.neo4j.{Neo4jExtractor.GRAPH_URL_CONFIG_KEY}': 'bolt://example.com:7687', + f'extractor.neo4j.{Neo4jExtractor.CYPHER_QUERY_CONFIG_KEY}': 'TEST_QUERY', + f'extractor.neo4j.{Neo4jExtractor.NEO4J_AUTH_USER}': 'TEST_USER', + f'extractor.neo4j.{Neo4jExtractor.NEO4J_AUTH_PW}': 'TEST_PW', + f'extractor.neo4j.{Neo4jExtractor.NEO4J_MAX_CONN_LIFE_TIME_SEC}': 50, + } + + self.conf = ConfigFactory.from_dict(config_dict) + + def text_extraction_with_empty_query_result(self: Any) -> None: + """ + Test Extraction with empty results from query + """ + with patch.object(GraphDatabase, 'driver'): + extractor = Neo4jExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + + extractor.results = [''] + result = extractor.extract() + self.assertIsNone(result) + + def test_extraction_with_single_query_result(self: Any) -> None: + """ + Test Extraction with single result from query + """ + with patch.object(GraphDatabase, 'driver'): + extractor = Neo4jExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + + extractor.results = ['test_result'] + result = extractor.extract() + self.assertEqual(result, 'test_result') + + # Ensure second result is None + result = extractor.extract() + self.assertIsNone(result) + + def test_extraction_with_multiple_query_result(self: Any) -> None: + """ + Test Extraction with multiple result from query + """ + with patch.object(GraphDatabase, 'driver'): + extractor = Neo4jExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + + extractor.results = ['test_result1', 'test_result2', 'test_result3'] + + result = extractor.extract() + self.assertEqual(result, 'test_result1') + + result = extractor.extract() + self.assertEqual(result, 'test_result2') + + result = extractor.extract() + self.assertEqual(result, 'test_result3') + + # Ensure next result is None + result = extractor.extract() + self.assertIsNone(result) + + def test_extraction_with_model_class(self: Any) -> None: + """ + Test Extraction using model class + """ + config_dict = { + f'extractor.neo4j.{Neo4jExtractor.GRAPH_URL_CONFIG_KEY}': 'bolt://example.com:7687', + f'extractor.neo4j.{Neo4jExtractor.CYPHER_QUERY_CONFIG_KEY}': 'TEST_QUERY', + f'extractor.neo4j.{Neo4jExtractor.NEO4J_AUTH_USER}': 'TEST_USER', + f'extractor.neo4j.{Neo4jExtractor.NEO4J_AUTH_PW}': 'TEST_PW', + f'extractor.neo4j.{Neo4jExtractor.NEO4J_MAX_CONN_LIFE_TIME_SEC}': 50, + f'extractor.neo4j.{Neo4jExtractor.MODEL_CLASS_CONFIG_KEY}': + 'databuilder.models.table_elasticsearch_document.TableESDocument' + } + + self.conf = ConfigFactory.from_dict(config_dict) + + with patch.object(GraphDatabase, 'driver'): + extractor = Neo4jExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + + result_dict = dict(database='test_database', + cluster='test_cluster', + schema='test_schema', + name='test_table_name', + display_name='test_schema.test_table_name', + 
key='test_table_key', + description='test_table_description', + last_updated_timestamp=123456789, + column_names=['test_col1', 'test_col2', 'test_col3'], + column_descriptions=['test_description1', 'test_description2', ''], + total_usage=100, + unique_usage=5, + tags=['hive'], + badges=['badge1'], + schema_description='schema_description', + programmatic_descriptions=['TEST']) + + extractor.results = [result_dict] + result_obj = extractor.extract() + + self.assertIsInstance(result_obj, TableESDocument) + self.assertDictEqual(vars(result_obj), result_dict) diff --git a/databuilder/tests/unit/extractor/test_neo4j_search_data_extractor.py b/databuilder/tests/unit/extractor/test_neo4j_search_data_extractor.py new file mode 100644 index 0000000000..c8b77426ea --- /dev/null +++ b/databuilder/tests/unit/extractor/test_neo4j_search_data_extractor.py @@ -0,0 +1,69 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import Any +from unittest.mock import patch + +from neo4j import GraphDatabase +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.neo4j_extractor import Neo4jExtractor +from databuilder.extractor.neo4j_search_data_extractor import Neo4jSearchDataExtractor +from databuilder.publisher.neo4j_csv_publisher import JOB_PUBLISH_TAG + + +class TestNeo4jExtractor(unittest.TestCase): + + def test_adding_filter(self: Any) -> None: + extractor = Neo4jSearchDataExtractor() + actual = extractor._add_publish_tag_filter('foo', 'MATCH (table:Table) {publish_tag_filter} RETURN table') + + self.assertEqual(actual, """MATCH (table:Table) WHERE table.published_tag = 'foo' RETURN table""") + + def test_not_adding_filter(self: Any) -> None: + extractor = Neo4jSearchDataExtractor() + actual = extractor._add_publish_tag_filter('', 'MATCH (table:Table) {publish_tag_filter} RETURN table') + + self.assertEqual(actual, """MATCH (table:Table) RETURN table""") + + def test_default_search_query(self: Any) -> None: + with patch.object(GraphDatabase, 'driver'): + extractor = Neo4jSearchDataExtractor() + conf = ConfigFactory.from_dict({ + f'extractor.search_data.extractor.neo4j.{Neo4jExtractor.GRAPH_URL_CONFIG_KEY}': + 'bolt://example.com:7687', + f'extractor.search_data.extractor.neo4j.{Neo4jExtractor.NEO4J_AUTH_USER}': 'test-user', + f'extractor.search_data.extractor.neo4j.{Neo4jExtractor.NEO4J_AUTH_PW}': 'test-passwd', + f'extractor.search_data.extractor.neo4j.{Neo4jExtractor.NEO4J_MAX_CONN_LIFE_TIME_SEC}': 50, + f'extractor.search_data.{Neo4jSearchDataExtractor.ENTITY_TYPE}': 'dashboard', + }) + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + self.assertEqual(extractor.cypher_query, + Neo4jSearchDataExtractor.DEFAULT_NEO4J_DASHBOARD_CYPHER_QUERY.format( + publish_tag_filter='')) + + def test_default_search_query_with_tag(self: Any) -> None: + with patch.object(GraphDatabase, 'driver'): + extractor = Neo4jSearchDataExtractor() + conf = ConfigFactory.from_dict({ + f'extractor.search_data.extractor.neo4j.{Neo4jExtractor.GRAPH_URL_CONFIG_KEY}': + 'bolt://example.com:7687', + f'extractor.search_data.extractor.neo4j.{Neo4jExtractor.NEO4J_AUTH_USER}': 'test-user', + f'extractor.search_data.extractor.neo4j.{Neo4jExtractor.NEO4J_AUTH_PW}': 'test-passwd', + f'extractor.search_data.extractor.neo4j.{Neo4jExtractor.NEO4J_MAX_CONN_LIFE_TIME_SEC}': 50, + f'extractor.search_data.{Neo4jSearchDataExtractor.ENTITY_TYPE}': 'dashboard', + f'extractor.search_data.{JOB_PUBLISH_TAG}': 
'test-date', + }) + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + + self.assertEqual(extractor.cypher_query, + Neo4jSearchDataExtractor.DEFAULT_NEO4J_DASHBOARD_CYPHER_QUERY.format( + publish_tag_filter="""WHERE dashboard.published_tag = 'test-date'""")) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/test_openlineage_extractor.py b/databuilder/tests/unit/extractor/test_openlineage_extractor.py new file mode 100644 index 0000000000..86c50ddb0f --- /dev/null +++ b/databuilder/tests/unit/extractor/test_openlineage_extractor.py @@ -0,0 +1,79 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.openlineage_extractor import OpenLineageTableLineageExtractor + + +class TestOpenlineageExtractor(unittest.TestCase): + + def test_amundsen_dataset_key(self) -> None: + """ + Test _amundsen_dataset_key method + """ + config_dict = { + f'extractor.openlineage_tablelineage.{OpenLineageTableLineageExtractor.TABLE_LINEAGE_FILE_LOCATION}': + 'example/sample_data/openlineage/sample_openlineage_events.ndjson', + f'extractor.openlineage_tablelineage.{OpenLineageTableLineageExtractor.CLUSTER_NAME}': 'datalab', + + } + self.conf = ConfigFactory.from_dict(config_dict) + extractor = OpenLineageTableLineageExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + + mock_dataset = {'name': 'mock_table', + 'namespace': 'postgresql', + 'database': 'testdb'} + + self.assertEqual('postgresql://datalab.testdb/mock_table', extractor._amundsen_dataset_key(mock_dataset)) + extractor.ol_namespace_override = 'hive' + self.assertEqual('hive://datalab.testdb/mock_table', extractor._amundsen_dataset_key(mock_dataset)) + + def test_extraction_with_model_class(self) -> None: + """ + Test Extraction + """ + config_dict = { + f'extractor.openlineage_tablelineage.{OpenLineageTableLineageExtractor.TABLE_LINEAGE_FILE_LOCATION}': + 'example/sample_data/openlineage/sample_openlineage_events.ndjson', + f'extractor.openlineage_tablelineage.{OpenLineageTableLineageExtractor.CLUSTER_NAME}': 'datalab', + + } + self.conf = ConfigFactory.from_dict(config_dict) + extractor = OpenLineageTableLineageExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + result = extractor.extract() + + self.assertEqual('hive://datalab.test/source_table1', result.table_key) + self.assertEqual(['hive://datalab.test/destination_table'], result.downstream_deps) + + result2 = extractor.extract() + + self.assertEqual('hive://datalab.test/source_table1', result2.table_key) + self.assertEqual(['hive://datalab.test/destination_table2'], result2.downstream_deps) + + result3 = extractor.extract() + self.assertEqual('hive://datalab.test/source_table1', result3.table_key) + self.assertEqual(['hive://datalab.test/destination_table'], result3.downstream_deps) + + result4 = extractor.extract() + self.assertEqual('hive://datalab.test/source_table1', result4.table_key) + self.assertEqual(['hive://datalab.test/destination_table4'], result4.downstream_deps) + + result5 = extractor.extract() + self.assertEqual('hive://datalab.test/source_table2', result5.table_key) + self.assertEqual(['hive://datalab.test/destination_table7'], result5.downstream_deps) + + result6 = extractor.extract() + self.assertEqual('hive://datalab.test/source_table3', result6.table_key) 
+ self.assertEqual(['hive://datalab.test/destination_table11'], result6.downstream_deps) + + result7 = extractor.extract() + self.assertEqual('hive://datalab.test/source_table3', result7.table_key) + self.assertEqual(['hive://datalab.test/destination_table10'], result7.downstream_deps) diff --git a/databuilder/tests/unit/extractor/test_oracle_metadata_extractor.py b/databuilder/tests/unit/extractor/test_oracle_metadata_extractor.py new file mode 100644 index 0000000000..8234c2cb14 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_oracle_metadata_extractor.py @@ -0,0 +1,283 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest +from typing import Any, Dict + +from mock import MagicMock, patch +from pyhocon import ConfigFactory + +from databuilder.extractor.oracle_metadata_extractor import OracleMetadataExtractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + + +class TestOracleMetadataExtractor(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + config_dict = { + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + OracleMetadataExtractor.CLUSTER_KEY: 'MY_CLUSTER', + OracleMetadataExtractor.DATABASE_KEY: 'oracle' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_extraction_with_empty_query_result(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = OracleMetadataExtractor() + extractor.init(self.conf) + + results = extractor.extract() + self.assertEqual(results, None) + + def test_extraction_with_single_result(self) -> None: + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection: + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + table = {'schema': 'test_schema', + 'name': 'test_table', + 'description': 'a table for testing', + 'cluster': + self.conf[OracleMetadataExtractor.CLUSTER_KEY] + } + + sql_execute.return_value = [ + self._union( + {'col_name': 'col_id1', + 'col_type': 'bigint', + 'col_description': 'description of id1', + 'col_sort_order': 0}, table), + self._union( + {'col_name': 'col_id2', + 'col_type': 'bigint', + 'col_description': 'description of id2', + 'col_sort_order': 1}, table), + self._union( + {'col_name': 'is_active', + 'col_type': 'boolean', + 'col_description': None, + 'col_sort_order': 2}, table), + self._union( + {'col_name': 'source', + 'col_type': 'varchar', + 'col_description': 'description of source', + 'col_sort_order': 3}, table), + self._union( + {'col_name': 'etl_created_at', + 'col_type': 'timestamp', + 'col_description': 'description of etl_created_at', + 'col_sort_order': 4}, table), + self._union( + {'col_name': 'ds', + 'col_type': 'varchar', + 'col_description': None, + 'col_sort_order': 5}, table) + ] + + extractor = OracleMetadataExtractor() + extractor.init(self.conf) + actual = extractor.extract() + expected = TableMetadata('oracle', 'MY_CLUSTER', 'test_schema', 'test_table', 'a table for testing', + [ColumnMetadata('col_id1', 'description of id1', 'bigint', 0), + ColumnMetadata('col_id2', 'description of id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 
'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5)]) + + self.assertEqual(expected.__repr__(), actual.__repr__()) + self.assertIsNone(extractor.extract()) + + def test_extraction_with_multiple_result(self) -> None: + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection: + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + table = {'schema': 'test_schema1', + 'name': 'test_table1', + 'description': 'test table 1', + 'cluster': + self.conf[OracleMetadataExtractor.CLUSTER_KEY] + } + + table1 = {'schema': 'test_schema1', + 'name': 'test_table2', + 'description': 'test table 2', + 'cluster': + self.conf[OracleMetadataExtractor.CLUSTER_KEY] + } + + table2 = {'schema': 'test_schema2', + 'name': 'test_table3', + 'description': 'test table 3', + 'cluster': + self.conf[OracleMetadataExtractor.CLUSTER_KEY] + } + + sql_execute.return_value = [ + self._union( + {'col_name': 'col_id1', + 'col_type': 'bigint', + 'col_description': 'description of col_id1', + 'col_sort_order': 0}, table), + self._union( + {'col_name': 'col_id2', + 'col_type': 'bigint', + 'col_description': 'description of col_id2', + 'col_sort_order': 1}, table), + self._union( + {'col_name': 'is_active', + 'col_type': 'boolean', + 'col_description': None, + 'col_sort_order': 2}, table), + self._union( + {'col_name': 'source', + 'col_type': 'varchar', + 'col_description': 'description of source', + 'col_sort_order': 3}, table), + self._union( + {'col_name': 'etl_created_at', + 'col_type': 'timestamp', + 'col_description': 'description of etl_created_at', + 'col_sort_order': 4}, table), + self._union( + {'col_name': 'ds', + 'col_type': 'varchar', + 'col_description': None, + 'col_sort_order': 5}, table), + self._union( + {'col_name': 'col_name', + 'col_type': 'varchar', + 'col_description': 'description of col_name', + 'col_sort_order': 0}, table1), + self._union( + {'col_name': 'col_name2', + 'col_type': 'varchar', + 'col_description': 'description of col_name2', + 'col_sort_order': 1}, table1), + self._union( + {'col_name': 'col_id3', + 'col_type': 'varchar', + 'col_description': 'description of col_id3', + 'col_sort_order': 0}, table2), + self._union( + {'col_name': 'col_name3', + 'col_type': 'varchar', + 'col_description': 'description of col_name3', + 'col_sort_order': 1}, table2) + ] + + extractor = OracleMetadataExtractor() + extractor.init(self.conf) + + expected = TableMetadata('oracle', + self.conf[OracleMetadataExtractor.CLUSTER_KEY], + 'test_schema1', 'test_table1', 'test table 1', + [ColumnMetadata('col_id1', 'description of col_id1', 'bigint', 0), + ColumnMetadata('col_id2', 'description of col_id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5)]) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableMetadata('oracle', + self.conf[OracleMetadataExtractor.CLUSTER_KEY], + 'test_schema1', 'test_table2', 'test table 2', + [ColumnMetadata('col_name', 'description of col_name', 'varchar', 0), + ColumnMetadata('col_name2', 'description of col_name2', 'varchar', 1)]) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableMetadata('oracle', + self.conf[OracleMetadataExtractor.CLUSTER_KEY], + 
'test_schema2', 'test_table3', 'test table 3', + [ColumnMetadata('col_id3', 'description of col_id3', 'varchar', 0), + ColumnMetadata('col_name3', 'description of col_name3', + 'varchar', 1)]) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + self.assertIsNone(extractor.extract()) + self.assertIsNone(extractor.extract()) + + def _union(self, + target: Dict[Any, Any], + extra: Dict[Any, Any]) -> Dict[Any, Any]: + target.update(extra) + return target + + +class TestOracleMetadataExtractorWithWhereClause(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.where_clause_suffix = """ + where table_schema in ('public') and table_name = 'movies' + """ + + config_dict = { + OracleMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY: self.where_clause_suffix, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = OracleMetadataExtractor() + extractor.init(self.conf) + self.assertTrue(self.where_clause_suffix in extractor.sql_stmt) + + +class TestOracleMetadataExtractorClusterKeyNoTableCatalog(unittest.TestCase): + # test when USE_CATALOG_AS_CLUSTER_NAME is false and CLUSTER_KEY is specified + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.cluster_key = "not_master" + + config_dict = { + OracleMetadataExtractor.CLUSTER_KEY: self.cluster_key, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = OracleMetadataExtractor() + extractor.init(self.conf) + self.assertTrue(self.cluster_key in extractor.sql_stmt) + + +class TestOracleMetadataExtractorNoClusterKeyNoTableCatalog(unittest.TestCase): + # test when USE_CATALOG_AS_CLUSTER_NAME is false and CLUSTER_KEY is NOT specified + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + config_dict = { + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = OracleMetadataExtractor() + extractor.init(self.conf) + self.assertTrue(OracleMetadataExtractor.DEFAULT_CLUSTER_NAME in extractor.sql_stmt) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/test_pandas_profiling_column_stats_extractor.py b/databuilder/tests/unit/extractor/test_pandas_profiling_column_stats_extractor.py new file mode 100644 index 0000000000..85f23b0a1e --- /dev/null +++ b/databuilder/tests/unit/extractor/test_pandas_profiling_column_stats_extractor.py @@ -0,0 +1,90 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import Any + +from mock import MagicMock +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.pandas_profiling_column_stats_extractor import PandasProfilingColumnStatsExtractor +from databuilder.models.table_stats import TableColumnStats + + +class TestPandasProfilingColumnStatsExtractor(unittest.TestCase): + report_data = { + 'analysis': { + 'date_start': '2021-05-17 10:10:15.142044' + }, + 'variables': { + 'column_1': { + 'mean': 5.120, + 'max': 15.23456 + }, + 'column_2': { + 'mean': 10 + } + } + } + + def setUp(self) -> None: + config = {'extractor.pandas_profiling.file_path': None} + config = ConfigFactory.from_dict({**config, **self._common_params()}) + + self.config = config + + @staticmethod + def _common_params() -> Any: + return {'extractor.pandas_profiling.table_name': 'table_name', + 'extractor.pandas_profiling.schema_name': 'schema_name', + 'extractor.pandas_profiling.database_name': 'database_name', + 'extractor.pandas_profiling.cluster_name': 'cluster_name'} + + def _get_extractor(self) -> Any: + extractor = PandasProfilingColumnStatsExtractor() + extractor.init(Scoped.get_scoped_conf(conf=self.config, scope=extractor.get_scope())) + + return extractor + + def test_extractor(self) -> None: + extractor = self._get_extractor() + + extractor._load_report = MagicMock(return_value=self.report_data) + + common = { + 'db': self._common_params().get('extractor.pandas_profiling.database_name'), + 'schema': self._common_params().get('extractor.pandas_profiling.schema_name'), + 'table_name': self._common_params().get('extractor.pandas_profiling.table_name'), + 'cluster': self._common_params().get('extractor.pandas_profiling.cluster_name'), + 'start_epoch': '1621246215', + 'end_epoch': '0' + } + compare_params = {'table', 'schema', 'db', 'col_name', 'start_epoch', + 'end_epoch', 'cluster', 'stat_type', 'stat_val'} + expected = [ + {x: spec[x] for x in compare_params if x in spec} for spec in + [ + TableColumnStats(**{**dict(stat_name='Mean', stat_val='5.12', col_name='column_1'), **common}).__dict__, + TableColumnStats( + **{**dict(stat_name='Maximum', stat_val='15.235', col_name='column_1'), **common}).__dict__, + TableColumnStats(**{**dict(stat_name='Mean', stat_val='10.0', col_name='column_2'), **common}).__dict__, + ] + ] + + result = [] + + while True: + stat = extractor.extract() + + if stat: + result.append(stat) + else: + break + + result_spec = [{x: spec.__dict__[x] for x in compare_params if x in spec.__dict__} for spec in result] + + for r in result: + self.assertIsInstance(r, TableColumnStats) + + self.assertListEqual(expected, result_spec) diff --git a/databuilder/tests/unit/extractor/test_postgres_metadata_extractor.py b/databuilder/tests/unit/extractor/test_postgres_metadata_extractor.py new file mode 100644 index 0000000000..3e7d5f3e7b --- /dev/null +++ b/databuilder/tests/unit/extractor/test_postgres_metadata_extractor.py @@ -0,0 +1,310 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest +from typing import Any, Dict + +from mock import MagicMock, patch +from pyhocon import ConfigFactory + +from databuilder.extractor.postgres_metadata_extractor import PostgresMetadataExtractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + + +class TestPostgresMetadataExtractor(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + config_dict = { + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + PostgresMetadataExtractor.CLUSTER_KEY: 'MY_CLUSTER', + PostgresMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME: False, + PostgresMetadataExtractor.DATABASE_KEY: 'postgres' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_extraction_with_empty_query_result(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = PostgresMetadataExtractor() + extractor.init(self.conf) + + results = extractor.extract() + self.assertEqual(results, None) + + def test_extraction_with_single_result(self) -> None: + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection: + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + table = {'schema': 'test_schema', + 'name': 'test_table', + 'description': 'a table for testing', + 'cluster': + self.conf[PostgresMetadataExtractor.CLUSTER_KEY] + } + + sql_execute.return_value = [ + self._union( + {'col_name': 'col_id1', + 'col_type': 'bigint', + 'col_description': 'description of id1', + 'col_sort_order': 0}, table), + self._union( + {'col_name': 'col_id2', + 'col_type': 'bigint', + 'col_description': 'description of id2', + 'col_sort_order': 1}, table), + self._union( + {'col_name': 'is_active', + 'col_type': 'boolean', + 'col_description': None, + 'col_sort_order': 2}, table), + self._union( + {'col_name': 'source', + 'col_type': 'varchar', + 'col_description': 'description of source', + 'col_sort_order': 3}, table), + self._union( + {'col_name': 'etl_created_at', + 'col_type': 'timestamp', + 'col_description': 'description of etl_created_at', + 'col_sort_order': 4}, table), + self._union( + {'col_name': 'ds', + 'col_type': 'varchar', + 'col_description': None, + 'col_sort_order': 5}, table) + ] + + extractor = PostgresMetadataExtractor() + extractor.init(self.conf) + actual = extractor.extract() + expected = TableMetadata('postgres', 'MY_CLUSTER', 'test_schema', 'test_table', 'a table for testing', + [ColumnMetadata('col_id1', 'description of id1', 'bigint', 0), + ColumnMetadata('col_id2', 'description of id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5)]) + + self.assertEqual(expected.__repr__(), actual.__repr__()) + self.assertIsNone(extractor.extract()) + + def test_extraction_with_multiple_result(self) -> None: + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection: + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + table = {'schema': 'test_schema1', + 'name': 'test_table1', + 'description': 'test 
table 1', + 'cluster': + self.conf[PostgresMetadataExtractor.CLUSTER_KEY] + } + + table1 = {'schema': 'test_schema1', + 'name': 'test_table2', + 'description': 'test table 2', + 'cluster': + self.conf[PostgresMetadataExtractor.CLUSTER_KEY] + } + + table2 = {'schema': 'test_schema2', + 'name': 'test_table3', + 'description': 'test table 3', + 'cluster': + self.conf[PostgresMetadataExtractor.CLUSTER_KEY] + } + + sql_execute.return_value = [ + self._union( + {'col_name': 'col_id1', + 'col_type': 'bigint', + 'col_description': 'description of col_id1', + 'col_sort_order': 0}, table), + self._union( + {'col_name': 'col_id2', + 'col_type': 'bigint', + 'col_description': 'description of col_id2', + 'col_sort_order': 1}, table), + self._union( + {'col_name': 'is_active', + 'col_type': 'boolean', + 'col_description': None, + 'col_sort_order': 2}, table), + self._union( + {'col_name': 'source', + 'col_type': 'varchar', + 'col_description': 'description of source', + 'col_sort_order': 3}, table), + self._union( + {'col_name': 'etl_created_at', + 'col_type': 'timestamp', + 'col_description': 'description of etl_created_at', + 'col_sort_order': 4}, table), + self._union( + {'col_name': 'ds', + 'col_type': 'varchar', + 'col_description': None, + 'col_sort_order': 5}, table), + self._union( + {'col_name': 'col_name', + 'col_type': 'varchar', + 'col_description': 'description of col_name', + 'col_sort_order': 0}, table1), + self._union( + {'col_name': 'col_name2', + 'col_type': 'varchar', + 'col_description': 'description of col_name2', + 'col_sort_order': 1}, table1), + self._union( + {'col_name': 'col_id3', + 'col_type': 'varchar', + 'col_description': 'description of col_id3', + 'col_sort_order': 0}, table2), + self._union( + {'col_name': 'col_name3', + 'col_type': 'varchar', + 'col_description': 'description of col_name3', + 'col_sort_order': 1}, table2) + ] + + extractor = PostgresMetadataExtractor() + extractor.init(self.conf) + + expected = TableMetadata('postgres', + self.conf[PostgresMetadataExtractor.CLUSTER_KEY], + 'test_schema1', 'test_table1', 'test table 1', + [ColumnMetadata('col_id1', 'description of col_id1', 'bigint', 0), + ColumnMetadata('col_id2', 'description of col_id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5)]) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableMetadata('postgres', + self.conf[PostgresMetadataExtractor.CLUSTER_KEY], + 'test_schema1', 'test_table2', 'test table 2', + [ColumnMetadata('col_name', 'description of col_name', 'varchar', 0), + ColumnMetadata('col_name2', 'description of col_name2', 'varchar', 1)]) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableMetadata('postgres', + self.conf[PostgresMetadataExtractor.CLUSTER_KEY], + 'test_schema2', 'test_table3', 'test table 3', + [ColumnMetadata('col_id3', 'description of col_id3', 'varchar', 0), + ColumnMetadata('col_name3', 'description of col_name3', + 'varchar', 1)]) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + self.assertIsNone(extractor.extract()) + self.assertIsNone(extractor.extract()) + + def _union(self, + target: Dict[Any, Any], + extra: Dict[Any, Any]) -> Dict[Any, Any]: + target.update(extra) + return target + + +class 
TestPostgresMetadataExtractorWithWhereClause(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.where_clause_suffix = """ + where table_schema in ('public') and table_name = 'movies' + """ + + config_dict = { + PostgresMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY: self.where_clause_suffix, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = PostgresMetadataExtractor() + extractor.init(self.conf) + self.assertTrue(self.where_clause_suffix in extractor.sql_stmt) + + +class TestPostgresMetadataExtractorClusterKeyNoTableCatalog(unittest.TestCase): + # test when USE_CATALOG_AS_CLUSTER_NAME is false and CLUSTER_KEY is specified + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.cluster_key = "not_master" + + config_dict = { + PostgresMetadataExtractor.CLUSTER_KEY: self.cluster_key, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + PostgresMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME: False + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = PostgresMetadataExtractor() + extractor.init(self.conf) + self.assertTrue(self.cluster_key in extractor.sql_stmt) + + +class TestPostgresMetadataExtractorNoClusterKeyNoTableCatalog(unittest.TestCase): + # test when USE_CATALOG_AS_CLUSTER_NAME is false and CLUSTER_KEY is NOT specified + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + config_dict = { + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + PostgresMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME: False + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = PostgresMetadataExtractor() + extractor.init(self.conf) + self.assertTrue(PostgresMetadataExtractor.DEFAULT_CLUSTER_NAME in extractor.sql_stmt) + + +class TestPostgresMetadataExtractorTableCatalogEnabled(unittest.TestCase): + # test when USE_CATALOG_AS_CLUSTER_NAME is true (CLUSTER_KEY should be ignored) + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.cluster_key = "not_master" + + config_dict = { + PostgresMetadataExtractor.CLUSTER_KEY: self.cluster_key, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + PostgresMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME: True + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = PostgresMetadataExtractor() + extractor.init(self.conf) + self.assertTrue('current_database()' in extractor.sql_stmt) + self.assertFalse(self.cluster_key in extractor.sql_stmt) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/test_presto_view_metadata_extractor.py b/databuilder/tests/unit/extractor/test_presto_view_metadata_extractor.py new file mode 100644 index 0000000000..5646f56974 --- /dev/null +++ 
b/databuilder/tests/unit/extractor/test_presto_view_metadata_extractor.py @@ -0,0 +1,91 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import base64 +import json +import logging +import unittest + +from mock import MagicMock, patch +from pyhocon import ConfigFactory + +from databuilder.extractor.presto_view_metadata_extractor import PrestoViewMetadataExtractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + + +class TestPrestoViewMetadataExtractor(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + config_dict = { + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_extraction_with_empty_result(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = PrestoViewMetadataExtractor() + extractor.init(self.conf) + + results = extractor.extract() + self.assertEqual(results, None) + + def test_extraction_with_multiple_views(self) -> None: + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection: + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + + columns1 = {'columns': [{'name': 'xyz', 'type': 'varchar'}, + {'name': 'xyy', 'type': 'double'}, + {'name': 'aaa', 'type': 'int'}, + {'name': 'ab', 'type': 'varchar'}]} + + columns2 = {'columns': [{'name': 'xyy', 'type': 'varchar'}, + {'name': 'ab', 'type': 'double'}, + {'name': 'aaa', 'type': 'int'}, + {'name': 'xyz', 'type': 'varchar'}]} + + sql_execute.return_value = [ + {'tbl_id': 2, + 'schema': 'test_schema2', + 'name': 'test_view2', + 'tbl_type': 'virtual_view', + 'view_original_text': base64.b64encode(json.dumps(columns2).encode()).decode("utf-8")}, + {'tbl_id': 1, + 'schema': 'test_schema1', + 'name': 'test_view1', + 'tbl_type': 'virtual_view', + 'view_original_text': base64.b64encode(json.dumps(columns1).encode()).decode("utf-8")}, + ] + + extractor = PrestoViewMetadataExtractor() + extractor.init(self.conf) + actual_first_view = extractor.extract() + expected_first_view = TableMetadata('presto', 'gold', 'test_schema2', 'test_view2', None, + [ColumnMetadata(u'xyy', None, u'varchar', 0), + ColumnMetadata(u'ab', None, u'double', 1), + ColumnMetadata(u'aaa', None, u'int', 2), + ColumnMetadata(u'xyz', None, u'varchar', 3)], + True) + self.assertEqual(expected_first_view.__repr__(), actual_first_view.__repr__()) + + actual_second_view = extractor.extract() + expected_second_view = TableMetadata('presto', 'gold', 'test_schema1', 'test_view1', None, + [ColumnMetadata(u'xyz', None, u'varchar', 0), + ColumnMetadata(u'xyy', None, u'double', 1), + ColumnMetadata(u'aaa', None, u'int', 2), + ColumnMetadata(u'ab', None, u'varchar', 3)], + True) + self.assertEqual(expected_second_view.__repr__(), actual_second_view.__repr__()) + + self.assertIsNone(extractor.extract()) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/test_redshift_metadata_extractor.py b/databuilder/tests/unit/extractor/test_redshift_metadata_extractor.py new file mode 100644 index 0000000000..4ebd770fb4 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_redshift_metadata_extractor.py @@ -0,0 +1,158 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest +from typing import Any, Dict + +from mock import MagicMock, patch +from pyhocon import ConfigFactory + +from databuilder.extractor.redshift_metadata_extractor import RedshiftMetadataExtractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + + +class TestRedshiftMetadataExtractor(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + config_dict = { + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + RedshiftMetadataExtractor.CLUSTER_KEY: 'MY_CLUSTER', + RedshiftMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME: False, + RedshiftMetadataExtractor.DATABASE_KEY: 'redshift' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_extraction_with_empty_query_result(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = RedshiftMetadataExtractor() + extractor.init(self.conf) + + results = extractor.extract() + self.assertEqual(results, None) + + def test_extraction_with_single_result(self) -> None: + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection: + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + table = {'schema': 'test_schema', + 'name': 'test_table', + 'description': 'a table for testing', + 'cluster': + self.conf[RedshiftMetadataExtractor.CLUSTER_KEY] + } + + sql_execute.return_value = [ + self._union( + {'col_name': 'col_id1', + 'col_type': 'bigint', + 'col_description': 'description of id1', + 'col_sort_order': 0}, table), + self._union( + {'col_name': 'col_id2', + 'col_type': 'bigint', + 'col_description': 'description of id2', + 'col_sort_order': 1}, table), + self._union( + {'col_name': 'is_active', + 'col_type': 'boolean', + 'col_description': None, + 'col_sort_order': 2}, table), + self._union( + {'col_name': 'source', + 'col_type': 'varchar', + 'col_description': 'description of source', + 'col_sort_order': 3}, table), + self._union( + {'col_name': 'etl_created_at', + 'col_type': 'timestamp', + 'col_description': 'description of etl_created_at', + 'col_sort_order': 4}, table), + self._union( + {'col_name': 'ds', + 'col_type': 'varchar', + 'col_description': None, + 'col_sort_order': 5}, table) + ] + + extractor = RedshiftMetadataExtractor() + extractor.init(self.conf) + actual = extractor.extract() + expected = TableMetadata('redshift', 'MY_CLUSTER', 'test_schema', 'test_table', 'a table for testing', + [ColumnMetadata('col_id1', 'description of id1', 'bigint', 0), + ColumnMetadata('col_id2', 'description of id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5)]) + + self.assertEqual(expected.__repr__(), actual.__repr__()) + self.assertIsNone(extractor.extract()) + + def _union(self, + target: Dict[Any, Any], + extra: Dict[Any, Any]) -> Dict[Any, Any]: + target.update(extra) + return target + + +class TestRedshiftMetadataExtractorWithWhereClause(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.where_clause_suffix = """ + table_schema in ('public') and table_name = 
'movies' + """ + + config_dict = { + RedshiftMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY: self.where_clause_suffix, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': + 'TEST_CONNECTION' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test extraction sql properly includes where suffix + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = RedshiftMetadataExtractor() + extractor.init(self.conf) + expected_where_clause = f'where {self.where_clause_suffix}' + + self.assertTrue(expected_where_clause in extractor.sql_stmt) + + +class TestRedshiftMetadataExtractorWithLegacyWhereClause(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.where_clause_suffix = """ + where table_schema in ('public') and table_name = 'movies' + """ + + config_dict = { + RedshiftMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY: self.where_clause_suffix, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': + 'TEST_CONNECTION' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test extraction sql properly includes where suffix using legacy specification + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = RedshiftMetadataExtractor() + extractor.init(self.conf) + + self.assertTrue(self.where_clause_suffix in extractor.sql_stmt) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/test_salesforce_extractor.py b/databuilder/tests/unit/extractor/test_salesforce_extractor.py new file mode 100644 index 0000000000..be7b691701 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_salesforce_extractor.py @@ -0,0 +1,198 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest +from collections import OrderedDict +from typing import Any, Dict + +from mock import patch +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.salesforce_extractor import SalesForceExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + +METADATA = { + "Account": { + "fields": [ + {"name": "Id", "inlineHelpText": "The Account Id", "type": "id"}, + {"name": "isDeleted", "inlineHelpText": "Deleted?", "type": "bool"}, + ] + }, + "Profile": { + "fields": [ + {"name": "Id", "inlineHelpText": "The Profile Id", "type": "id"}, + { + "name": "Business", + "inlineHelpText": "Important Bizness", + "type": "string", + }, + ] + }, +} + + +class MockSalesForce: + def __init__(self) -> None: + pass + + def describe(self) -> Dict: + return { + "encoding": "UTF-8", + "maxBatchSize": 200, + "sobjects": [ + OrderedDict( + [ + ("activateable", False), + ("createable", False), + ("custom", False), + ("customSetting", False), + ("deletable", False), + ("deprecatedAndHidden", False), + ("feedEnabled", False), + ("hasSubtypes", False), + ("isSubtype", False), + ("keyPrefix", None), + ("label", object_name), + ("labelPlural", object_name), + ("layoutable", False), + ("mergeable", False), + ("mruEnabled", False), + ("name", object_name), + ("queryable", True), + ("replicateable", False), + ("retrieveable", True), + ("searchable", False), + ("triggerable", False), + ("undeletable", False), + ("updateable", False), + ( + "urls", + OrderedDict( + [ + ( + "rowTemplate", + f"/services/data/v42.0/sobjects/{object_name}/" + "{ID}", + ), + ( + "describe", + f"/services/data/v42.0/sobjects/{object_name}/describe", + ), + ( + "sobject", + f"/services/data/v42.0/sobjects/{object_name}", + ), + ] + ), + ), + ] + ) + for object_name in METADATA.keys() + ], + } + + def restful(self, path: str) -> Dict: + object_name = path.split("/")[1] + return METADATA[object_name] + + +class TestSalesForceExtractor(unittest.TestCase): + def setUp(self) -> None: + self.config = { + f"extractor.salesforce_metadata.{SalesForceExtractor.USERNAME_KEY}": "user", + f"extractor.salesforce_metadata.{SalesForceExtractor.PASSWORD_KEY}": "password", + f"extractor.salesforce_metadata.{SalesForceExtractor.SECURITY_TOKEN_KEY}": "token", + f"extractor.salesforce_metadata.{SalesForceExtractor.SCHEMA_KEY}": "default", + f"extractor.salesforce_metadata.{SalesForceExtractor.CLUSTER_KEY}": "gold", + f"extractor.salesforce_metadata.{SalesForceExtractor.DATABASE_KEY}": "salesforce", + } + + @patch("databuilder.extractor.salesforce_extractor.Salesforce") + def test_extraction_one_object(self, mock_salesforce: Any) -> None: + mock_salesforce.return_value = MockSalesForce() + config_dict: Dict = { + f"extractor.salesforce_metadata.{SalesForceExtractor.OBJECT_NAMES_KEY}": [ + "Account" + ], + **self.config, + } + conf = ConfigFactory.from_dict(config_dict) + + mock_salesforce.return_value = MockSalesForce() + extractor = SalesForceExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, scope=extractor.get_scope())) + result = extractor.extract() + self.assertIsInstance(result, TableMetadata) + + expected = TableMetadata( + "salesforce", + "gold", + "default", + "Account", + None, + [ + ColumnMetadata("Id", "The Account Id", "id", 0, []), + ColumnMetadata("isDeleted", "Deleted?", "bool", 1, []), + ], + False, + [], + ) + + self.assertEqual(expected.__repr__(), result.__repr__()) + + self.assertIsNone(extractor.extract()) + + 
@patch("databuilder.extractor.salesforce_extractor.Salesforce") + def test_extraction_multiple_objects(self, mock_salesforce: Any) -> None: + mock_salesforce.return_value = MockSalesForce() + config_dict: Dict = { + f"extractor.salesforce_metadata.{SalesForceExtractor.OBJECT_NAMES_KEY}": [ + "Account", + "Profile", + ], + **self.config, + } + conf = ConfigFactory.from_dict(config_dict) + + mock_salesforce.return_value = MockSalesForce() + extractor = SalesForceExtractor() + extractor.init(Scoped.get_scoped_conf(conf=conf, scope=extractor.get_scope())) + + results = [extractor.extract(), extractor.extract()] + for result in results: + self.assertIsInstance(result, TableMetadata) + + expecteds = [ + TableMetadata( + "salesforce", + "gold", + "default", + "Account", + None, + [ + ColumnMetadata("Id", "The Account Id", "id", 0, []), + ColumnMetadata("isDeleted", "Deleted?", "bool", 1, []), + ], + False, + [], + ), + TableMetadata( + "salesforce", + "gold", + "default", + "Profile", + None, + [ + # These columns are sorted alphabetically + ColumnMetadata("Business", "Important Bizness", "string", 0, []), + ColumnMetadata("Id", "The Profile Id", "id", 1, []), + ], + False, + [], + ), + ] + + for result, expected in zip(results, expecteds): + self.assertEqual(expected.__repr__(), result.__repr__()) + + self.assertIsNone(extractor.extract()) diff --git a/databuilder/tests/unit/extractor/test_snowflake_metadata_extractor.py b/databuilder/tests/unit/extractor/test_snowflake_metadata_extractor.py new file mode 100644 index 0000000000..e25758cd21 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_snowflake_metadata_extractor.py @@ -0,0 +1,399 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest +from typing import Any, Dict + +from mock import MagicMock, patch +from pyhocon import ConfigFactory + +from databuilder.extractor.snowflake_metadata_extractor import SnowflakeMetadataExtractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + + +class TestSnowflakeMetadataExtractor(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + config_dict = { + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + f'extractor.snowflake_metadata.{SnowflakeMetadataExtractor.CLUSTER_KEY}': 'MY_CLUSTER', + f'extractor.snowflake_metadata.{SnowflakeMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME}': False, + f'extractor.snowflake_metadata.{SnowflakeMetadataExtractor.SNOWFLAKE_DATABASE_KEY}': 'prod' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_extraction_with_empty_query_result(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = SnowflakeMetadataExtractor() + extractor.init(self.conf) + + results = extractor.extract() + self.assertEqual(results, None) + + def test_extraction_with_single_result(self) -> None: + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection: + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + table = {'schema': 'test_schema', + 'name': 'test_table', + 'description': 'a table for testing', + 'cluster': self.conf[f'extractor.snowflake_metadata.{SnowflakeMetadataExtractor.CLUSTER_KEY}'], + 'is_view': 'false' + } + + 
sql_execute.return_value = [ + self._union( + {'col_name': 'col_id1', + 'col_type': 'number', + 'col_description': 'description of id1', + 'col_sort_order': 0}, table), + self._union( + {'col_name': 'col_id2', + 'col_type': 'number', + 'col_description': 'description of id2', + 'col_sort_order': 1}, table), + self._union( + {'col_name': 'is_active', + 'col_type': 'boolean', + 'col_description': None, + 'col_sort_order': 2}, table), + self._union( + {'col_name': 'source', + 'col_type': 'varchar', + 'col_description': 'description of source', + 'col_sort_order': 3}, table), + self._union( + {'col_name': 'etl_created_at', + 'col_type': 'timestamp_ltz', + 'col_description': 'description of etl_created_at', + 'col_sort_order': 4}, table), + self._union( + {'col_name': 'ds', + 'col_type': 'varchar', + 'col_description': None, + 'col_sort_order': 5}, table) + ] + + extractor = SnowflakeMetadataExtractor() + extractor.init(self.conf) + actual = extractor.extract() + expected = TableMetadata('snowflake', 'MY_CLUSTER', 'test_schema', 'test_table', 'a table for testing', + [ColumnMetadata('col_id1', 'description of id1', 'number', 0), + ColumnMetadata('col_id2', 'description of id2', 'number', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', + 'timestamp_ltz', 4), + ColumnMetadata('ds', None, 'varchar', 5)]) + + self.assertEqual(expected.__repr__(), actual.__repr__()) + self.assertIsNone(extractor.extract()) + + def test_extraction_with_multiple_result(self) -> None: + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection: + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + table = {'schema': 'test_schema1', + 'name': 'test_table1', + 'description': 'test table 1', + 'cluster': + self.conf[f'extractor.snowflake_metadata.{SnowflakeMetadataExtractor.CLUSTER_KEY}'], + 'is_view': 'nottrue' + } + + table1 = {'schema': 'test_schema1', + 'name': 'test_table2', + 'description': 'test table 2', + 'cluster': + self.conf[f'extractor.snowflake_metadata.{SnowflakeMetadataExtractor.CLUSTER_KEY}'], + 'is_view': 'false' + } + + table2 = {'schema': 'test_schema2', + 'name': 'test_table3', + 'description': 'test table 3', + 'cluster': + self.conf[f'extractor.snowflake_metadata.{SnowflakeMetadataExtractor.CLUSTER_KEY}'], + 'is_view': 'true' + } + + sql_execute.return_value = [ + self._union( + {'col_name': 'col_id1', + 'col_type': 'number', + 'col_description': 'description of col_id1', + 'col_sort_order': 0}, table), + self._union( + {'col_name': 'col_id2', + 'col_type': 'number', + 'col_description': 'description of col_id2', + 'col_sort_order': 1}, table), + self._union( + {'col_name': 'is_active', + 'col_type': 'boolean', + 'col_description': None, + 'col_sort_order': 2}, table), + self._union( + {'col_name': 'source', + 'col_type': 'varchar', + 'col_description': 'description of source', + 'col_sort_order': 3}, table), + self._union( + {'col_name': 'etl_created_at', + 'col_type': 'timestamp_ltz', + 'col_description': 'description of etl_created_at', + 'col_sort_order': 4}, table), + self._union( + {'col_name': 'ds', + 'col_type': 'varchar', + 'col_description': None, + 'col_sort_order': 5}, table), + self._union( + {'col_name': 'col_name', + 'col_type': 'varchar', + 'col_description': 'description of col_name', + 'col_sort_order': 0}, table1), + self._union( + 
{'col_name': 'col_name2', + 'col_type': 'varchar', + 'col_description': 'description of col_name2', + 'col_sort_order': 1}, table1), + self._union( + {'col_name': 'col_id3', + 'col_type': 'varchar', + 'col_description': 'description of col_id3', + 'col_sort_order': 0}, table2), + self._union( + {'col_name': 'col_name3', + 'col_type': 'varchar', + 'col_description': 'description of col_name3', + 'col_sort_order': 1}, table2) + ] + + extractor = SnowflakeMetadataExtractor() + extractor.init(self.conf) + + expected = TableMetadata('snowflake', + self.conf[ + f'extractor.snowflake_metadata.{SnowflakeMetadataExtractor.CLUSTER_KEY}'], + 'test_schema1', 'test_table1', 'test table 1', + [ColumnMetadata('col_id1', 'description of col_id1', 'number', 0), + ColumnMetadata('col_id2', 'description of col_id2', 'number', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', + 'timestamp_ltz', 4), + ColumnMetadata('ds', None, 'varchar', 5)]) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableMetadata('snowflake', + self.conf[ + f'extractor.snowflake_metadata.{SnowflakeMetadataExtractor.CLUSTER_KEY}'], + 'test_schema1', 'test_table2', 'test table 2', + [ColumnMetadata('col_name', 'description of col_name', 'varchar', 0), + ColumnMetadata('col_name2', 'description of col_name2', 'varchar', 1)]) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableMetadata('snowflake', + self.conf[ + f'extractor.snowflake_metadata.{SnowflakeMetadataExtractor.CLUSTER_KEY}'], + 'test_schema2', 'test_table3', 'test table 3', + [ColumnMetadata('col_id3', 'description of col_id3', 'varchar', 0), + ColumnMetadata('col_name3', 'description of col_name3', + 'varchar', 1)], + True) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + self.assertIsNone(extractor.extract()) + self.assertIsNone(extractor.extract()) + + def _union(self, + target: Dict[Any, Any], + extra: Dict[Any, Any]) -> Dict[Any, Any]: + target.update(extra) + return target + + +class TestSnowflakeMetadataExtractorWithWhereClause(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.where_clause_suffix = """ + where table_schema in ('public') and table_name = 'movies' + """ + + config_dict = { + SnowflakeMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY: self.where_clause_suffix, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': + 'TEST_CONNECTION' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = SnowflakeMetadataExtractor() + extractor.init(self.conf) + self.assertTrue(self.where_clause_suffix in extractor.sql_stmt) + + +class TestSnowflakeMetadataExtractorClusterKeyNoTableCatalog(unittest.TestCase): + # test when USE_CATALOG_AS_CLUSTER_NAME is false and CLUSTER_KEY is specified + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.cluster_key = "not_master" + + config_dict = { + SnowflakeMetadataExtractor.CLUSTER_KEY: self.cluster_key, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': + 'TEST_CONNECTION', + SnowflakeMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME: False + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test 
that the configured cluster key appears in the generated SQL statement
+        """
+        with patch.object(SQLAlchemyExtractor, '_get_connection'):
+            extractor = SnowflakeMetadataExtractor()
+            extractor.init(self.conf)
+            self.assertTrue(self.cluster_key in extractor.sql_stmt)
+
+
+class TestSnowflakeMetadataExtractorDefaultSnowflakeDatabaseKey(unittest.TestCase):
+    # test when SNOWFLAKE_DATABASE_KEY is specified
+    def setUp(self) -> None:
+        logging.basicConfig(level=logging.INFO)
+        self.snowflake_database_key = "not_prod"
+
+        config_dict = {
+            SnowflakeMetadataExtractor.SNOWFLAKE_DATABASE_KEY: self.snowflake_database_key,
+            f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}':
+                'TEST_CONNECTION'
+        }
+        self.conf = ConfigFactory.from_dict(config_dict)
+
+    def test_sql_statement(self) -> None:
+        """
+        Test that the configured Snowflake database key appears in the generated SQL statement
+        """
+        with patch.object(SQLAlchemyExtractor, '_get_connection'):
+            extractor = SnowflakeMetadataExtractor()
+            extractor.init(self.conf)
+            self.assertTrue(self.snowflake_database_key in extractor.sql_stmt)
+
+
+class TestSnowflakeMetadataExtractorDefaultDatabaseKey(unittest.TestCase):
+    # test when DATABASE_KEY is specified
+    def setUp(self) -> None:
+        logging.basicConfig(level=logging.INFO)
+        self.database_key = 'not_snowflake'
+
+        config_dict = {
+            SnowflakeMetadataExtractor.DATABASE_KEY: self.database_key,
+            f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}':
+                'TEST_CONNECTION'
+        }
+        self.conf = ConfigFactory.from_dict(config_dict)
+
+    def test_sql_statement(self) -> None:
+        """
+        Test that the database key is not embedded in the generated SQL statement
+        """
+        with patch.object(SQLAlchemyExtractor, '_get_connection'):
+            extractor = SnowflakeMetadataExtractor()
+            extractor.init(self.conf)
+            self.assertFalse(self.database_key in extractor.sql_stmt)
+
+    def test_extraction_with_database_specified(self) -> None:
+        with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection:
+            connection = MagicMock()
+            mock_connection.return_value = connection
+            sql_execute = MagicMock()
+            connection.execute = sql_execute
+
+            sql_execute.return_value = [
+                {'schema': 'test_schema',
+                 'name': 'test_table',
+                 'description': 'a table for testing',
+                 'cluster': 'MY_CLUSTER',
+                 'is_view': 'false',
+                 'col_name': 'ds',
+                 'col_type': 'varchar',
+                 'col_description': None,
+                 'col_sort_order': 0}
+            ]
+
+            extractor = SnowflakeMetadataExtractor()
+            extractor.init(self.conf)
+            actual = extractor.extract()
+            expected = TableMetadata(
+                self.database_key, 'MY_CLUSTER', 'test_schema', 'test_table', 'a table for testing',
+                [ColumnMetadata('ds', None, 'varchar', 0)]
+            )
+
+            self.assertEqual(expected.__repr__(), actual.__repr__())
+            self.assertIsNone(extractor.extract())
+
+
+class TestSnowflakeMetadataExtractorNoClusterKeyNoTableCatalog(unittest.TestCase):
+    # test when USE_CATALOG_AS_CLUSTER_NAME is false and CLUSTER_KEY is NOT specified
+    def setUp(self) -> None:
+        logging.basicConfig(level=logging.INFO)
+
+        config_dict = {
+            f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}':
+                'TEST_CONNECTION',
+            SnowflakeMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME: False
+        }
+        self.conf = ConfigFactory.from_dict(config_dict)
+
+    def test_sql_statement(self) -> None:
+        """
+        Test that the default cluster name is used in the generated SQL statement
+        """
+        with patch.object(SQLAlchemyExtractor, '_get_connection'):
+            extractor = SnowflakeMetadataExtractor()
+            extractor.init(self.conf)
+            self.assertTrue(SnowflakeMetadataExtractor.DEFAULT_CLUSTER_NAME in extractor.sql_stmt)
+
+
+class TestSnowflakeMetadataExtractorTableCatalogEnabled(unittest.TestCase):
+    # test when USE_CATALOG_AS_CLUSTER_NAME is true (CLUSTER_KEY should be ignored)
+    def setUp(self) -> None:
+        logging.basicConfig(level=logging.INFO)
+        self.cluster_key = "not_master"
+
+        config_dict = {
+            SnowflakeMetadataExtractor.CLUSTER_KEY: self.cluster_key,
+            f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}':
+                'TEST_CONNECTION',
+            SnowflakeMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME: True
+        }
+        self.conf = ConfigFactory.from_dict(config_dict)
+
+    def test_sql_statement(self) -> None:
+        """
+        Test that the table catalog is used as the cluster name in the generated SQL statement
+        """
+        with patch.object(SQLAlchemyExtractor, '_get_connection'):
+            extractor = SnowflakeMetadataExtractor()
+            extractor.init(self.conf)
+            self.assertTrue('table_catalog' in extractor.sql_stmt)
+            self.assertFalse(self.cluster_key in extractor.sql_stmt)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/databuilder/tests/unit/extractor/test_snowflake_table_last_updated_extractor.py b/databuilder/tests/unit/extractor/test_snowflake_table_last_updated_extractor.py
new file mode 100644
index 0000000000..e6308079d1
--- /dev/null
+++ b/databuilder/tests/unit/extractor/test_snowflake_table_last_updated_extractor.py
@@ -0,0 +1,297 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+from mock import MagicMock, patch
+from pyhocon import ConfigFactory
+
+from databuilder.extractor.snowflake_table_last_updated_extractor import SnowflakeTableLastUpdatedExtractor
+from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor
+from databuilder.models.table_last_updated import TableLastUpdated
+
+
+class TestSnowflakeTableLastUpdatedExtractor(unittest.TestCase):
+    def setUp(self) -> None:
+        config_dict = {
+            f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}':
+                'TEST_CONNECTION',
+            f'extractor.snowflake_table_last_updated.{SnowflakeTableLastUpdatedExtractor.CLUSTER_KEY}':
+                'MY_CLUSTER',
+            f'extractor.snowflake_table_last_updated.{SnowflakeTableLastUpdatedExtractor.USE_CATALOG_AS_CLUSTER_NAME}':
+                False,
+            f'extractor.snowflake_table_last_updated.{SnowflakeTableLastUpdatedExtractor.SNOWFLAKE_DATABASE_KEY}':
+                'prod'
+        }
+        self.conf = ConfigFactory.from_dict(config_dict)
+
+    def test_extraction_with_empty_query_result(self) -> None:
+        """
+        Test Extraction with empty result from query
+        """
+        with patch.object(SQLAlchemyExtractor, '_get_connection'):
+            extractor = SnowflakeTableLastUpdatedExtractor()
+            extractor.init(self.conf)
+
+            results = extractor.extract()
+            self.assertIsNone(results)
+
+    def test_extraction_with_single_result(self) -> None:
+        """
+        Test Extraction with default cluster and database and with one table as result
+        """
+        with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection:
+            connection = MagicMock()
+            mock_connection.return_value = connection
+            sql_execute = MagicMock()
+            connection.execute = sql_execute
+            sql_execute.return_value = [
+                {'schema': 'test_schema',
+                 'table_name': 'test_table',
+                 'last_updated_time': 1000,
+                 'cluster': self.conf[
+                     f'extractor.snowflake_table_last_updated.{SnowflakeTableLastUpdatedExtractor.CLUSTER_KEY}'],
+                 }
+            ]
+
+            extractor = SnowflakeTableLastUpdatedExtractor()
+            extractor.init(self.conf)
+            actual = extractor.extract()
+
+            expected = TableLastUpdated(schema='test_schema', table_name='test_table',
+                                        last_updated_time_epoch=1000,
+                                        db='snowflake', cluster='MY_CLUSTER')
+            self.assertEqual(expected.__repr__(), actual.__repr__())
+            self.assertIsNone(extractor.extract())
+
+    def 
test_extraction_with_multiple_result(self) -> None: + """ + Test Extraction with default cluster and database and with multiple tables as result + """ + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection: + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + + default_cluster = self.conf[ + f'extractor.snowflake_table_last_updated.{SnowflakeTableLastUpdatedExtractor.CLUSTER_KEY}'] + + table = {'schema': 'test_schema1', + 'table_name': 'test_table1', + 'last_updated_time': 1000, + 'cluster': default_cluster + } + + table1 = {'schema': 'test_schema1', + 'table_name': 'test_table2', + 'last_updated_time': 2000, + 'cluster': default_cluster + } + + table2 = {'schema': 'test_schema2', + 'table_name': 'test_table3', + 'last_updated_time': 3000, + 'cluster': default_cluster + } + + sql_execute.return_value = [table, table1, table2] + + extractor = SnowflakeTableLastUpdatedExtractor() + extractor.init(self.conf) + + expected = TableLastUpdated(schema='test_schema1', table_name='test_table1', + last_updated_time_epoch=1000, + db='snowflake', cluster='MY_CLUSTER') + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableLastUpdated(schema='test_schema1', table_name='test_table2', + last_updated_time_epoch=2000, + db='snowflake', cluster='MY_CLUSTER') + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableLastUpdated(schema='test_schema2', table_name='test_table3', + last_updated_time_epoch=3000, + db='snowflake', cluster='MY_CLUSTER') + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + self.assertIsNone(extractor.extract()) + + +class TestSnowflakeTableLastUpdatedExtractorWithWhereClause(unittest.TestCase): + """ + Test 'where_clause' config key in extractor + """ + + def setUp(self) -> None: + self.where_clause_suffix = """ + where table_schema in ('public') and table_name = 'movies' + """ + + config_dict = { + SnowflakeTableLastUpdatedExtractor.WHERE_CLAUSE_SUFFIX_KEY: self.where_clause_suffix, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + test where clause in extractor sql statement + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = SnowflakeTableLastUpdatedExtractor() + extractor.init(self.conf) + self.assertTrue(self.where_clause_suffix in extractor.sql_stmt) + + +class TestSnowflakeTableLastUpdatedExtractorClusterKeyNoTableCatalog(unittest.TestCase): + """ + Test with 'USE_CATALOG_AS_CLUSTER_NAME' is false and 'CLUSTER_KEY' is specified + """ + + def setUp(self) -> None: + self.cluster_key = "not_master" + + config_dict = { + SnowflakeTableLastUpdatedExtractor.CLUSTER_KEY: self.cluster_key, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': + 'TEST_CONNECTION', + SnowflakeTableLastUpdatedExtractor.USE_CATALOG_AS_CLUSTER_NAME: False + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test cluster_key in extractor sql stmt + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = SnowflakeTableLastUpdatedExtractor() + extractor.init(self.conf) + self.assertTrue(self.cluster_key in extractor.sql_stmt) + + +class TestSnowflakeTableLastUpdatedExtractorDefaultSnowflakeDatabaseKey(unittest.TestCase): + """ + Test with SNOWFLAKE_DATABASE_KEY config 
specified + """ + + def setUp(self) -> None: + self.snowflake_database_key = "not_prod" + + config_dict = { + SnowflakeTableLastUpdatedExtractor.SNOWFLAKE_DATABASE_KEY: self.snowflake_database_key, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test SNOWFLAKE_DATABASE_KEY in extractor sql stmt + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = SnowflakeTableLastUpdatedExtractor() + extractor.init(self.conf) + self.assertTrue(self.snowflake_database_key in extractor.sql_stmt) + + +class TestSnowflakeTableLastUpdatedExtractorDefaultDatabaseKey(unittest.TestCase): + """ + Test with DATABASE_KEY config specified + """ + + def setUp(self) -> None: + self.database_key = 'not_snowflake' + + config_dict = { + SnowflakeTableLastUpdatedExtractor.DATABASE_KEY: self.database_key, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test DATABASE_KEY in extractor sql stmt + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = SnowflakeTableLastUpdatedExtractor() + extractor.init(self.conf) + self.assertFalse(self.database_key in extractor.sql_stmt) + + def test_extraction_with_database_specified(self) -> None: + """ + Test DATABASE_KEY in extractor result + """ + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection: + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + + sql_execute.return_value = [ + {'schema': 'test_schema', + 'table_name': 'test_table', + 'last_updated_time': 1000, + 'cluster': 'MY_CLUSTER', + } + ] + + extractor = SnowflakeTableLastUpdatedExtractor() + extractor.init(self.conf) + actual = extractor.extract() + expected = TableLastUpdated(schema='test_schema', table_name='test_table', + last_updated_time_epoch=1000, + db=self.database_key, cluster='MY_CLUSTER') + self.assertEqual(expected.__repr__(), actual.__repr__()) + self.assertIsNone(extractor.extract()) + + +class TestSnowflakeTableLastUpdatedExtractorNoClusterKeyNoTableCatalog(unittest.TestCase): + """ + Test when USE_CATALOG_AS_CLUSTER_NAME is false and CLUSTER_KEY is NOT specified + """ + + def setUp(self) -> None: + config_dict = { + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + SnowflakeTableLastUpdatedExtractor.USE_CATALOG_AS_CLUSTER_NAME: False + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test cluster name in extract sql stmt + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = SnowflakeTableLastUpdatedExtractor() + extractor.init(self.conf) + self.assertTrue(SnowflakeTableLastUpdatedExtractor.DEFAULT_CLUSTER_NAME in extractor.sql_stmt) + + +class TestSnowflakeTableLastUpdatedExtractorTableCatalogEnabled(unittest.TestCase): + """ + Test when USE_CATALOG_AS_CLUSTER_NAME is true (CLUSTER_KEY should be ignored) + """ + + def setUp(self) -> None: + self.cluster_key = "not_master" + + config_dict = { + SnowflakeTableLastUpdatedExtractor.CLUSTER_KEY: self.cluster_key, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + SnowflakeTableLastUpdatedExtractor.USE_CATALOG_AS_CLUSTER_NAME: True + } + self.conf = ConfigFactory.from_dict(config_dict) + + def 
test_sql_statement(self) -> None: + """ + Ensure catalog is used as cluster in extract sql stmt + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = SnowflakeTableLastUpdatedExtractor() + extractor.init(self.conf) + self.assertTrue('table_catalog' in extractor.sql_stmt) + self.assertFalse(self.cluster_key in extractor.sql_stmt) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/test_sql_alchemy_extractor.py b/databuilder/tests/unit/extractor/test_sql_alchemy_extractor.py new file mode 100644 index 0000000000..32b4ed6e1e --- /dev/null +++ b/databuilder/tests/unit/extractor/test_sql_alchemy_extractor.py @@ -0,0 +1,149 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import Any, Dict + +from mock import patch +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor + + +class TestSqlAlchemyExtractor(unittest.TestCase): + + def setUp(self) -> None: + config_dict = { + 'extractor.sqlalchemy.conn_string': 'TEST_CONNECTION', + 'extractor.sqlalchemy.extract_sql': 'SELECT 1 FROM TEST_TABLE;' + } + self.conf = ConfigFactory.from_dict(config_dict) + + @patch.object(SQLAlchemyExtractor, '_get_connection') + def test_extraction_with_empty_query_result(self: Any, + mock_method: Any) -> None: + """ + Test Extraction with empty result from query + """ + extractor = SQLAlchemyExtractor() + extractor.results = [''] + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + + results = extractor.extract() + self.assertEqual(results, '') + + @patch.object(SQLAlchemyExtractor, '_get_connection') + def test_extraction_with_single_query_result(self: Any, + mock_method: Any) -> None: + """ + Test Extraction from single result from query + """ + extractor = SQLAlchemyExtractor() + extractor.results = [('test_result')] + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + results = extractor.extract() + self.assertEqual(results, 'test_result') + + @patch.object(SQLAlchemyExtractor, '_get_connection') + def test_extraction_with_multiple_query_result(self: Any, + mock_method: Any) -> None: + """ + Test Extraction from list of results from query + """ + extractor = SQLAlchemyExtractor() + extractor.results = ['test_result', 'test_result2', 'test_result3'] + extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + result = [extractor.extract() for _ in range(3)] + + self.assertEqual(len(result), 3) + self.assertEqual(result, + ['test_result', 'test_result2', 'test_result3']) + + @patch.object(SQLAlchemyExtractor, '_get_connection') + def test_extraction_with_model_class(self: Any, mock_method: Any) -> None: + """ + Test Extraction using model class + """ + config_dict = { + 'extractor.sqlalchemy.conn_string': 'TEST_CONNECTION', + 'extractor.sqlalchemy.extract_sql': 'SELECT 1 FROM TEST_TABLE;', + 'extractor.sqlalchemy.model_class': + 'tests.unit.extractor.test_sql_alchemy_extractor.TableMetadataResult' + } + self.conf = ConfigFactory.from_dict(config_dict) + + extractor = SQLAlchemyExtractor() + extractor.results = [dict(database='test_database', + schema='test_schema', + name='test_table', + description='test_description', + column_name='test_column_name', + column_type='test_column_type', + column_comment='test_column_comment', + owner='test_owner')] + + 
extractor.init(Scoped.get_scoped_conf(conf=self.conf, + scope=extractor.get_scope())) + + result = extractor.extract() + + self.assertIsInstance(result, TableMetadataResult) + self.assertEqual(result.name, 'test_table') + + @patch('databuilder.extractor.sql_alchemy_extractor.create_engine') + def test_get_connection(self: Any, mock_method: Any) -> None: + """ + Test that configs are passed through correctly to the _get_connection method + """ + extractor = SQLAlchemyExtractor() + config_dict: Dict[str, Any] = { + 'extractor.sqlalchemy.conn_string': 'TEST_CONNECTION', + 'extractor.sqlalchemy.extract_sql': 'SELECT 1 FROM TEST_TABLE;' + } + conf = ConfigFactory.from_dict(config_dict) + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + extractor._get_connection() + mock_method.assert_called_with('TEST_CONNECTION', connect_args={}) + + extractor = SQLAlchemyExtractor() + config_dict = { + 'extractor.sqlalchemy.conn_string': 'TEST_CONNECTION', + 'extractor.sqlalchemy.extract_sql': 'SELECT 1 FROM TEST_TABLE;', + 'extractor.sqlalchemy.connect_args': {"protocol": "https"}, + } + conf = ConfigFactory.from_dict(config_dict) + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + extractor._get_connection() + mock_method.assert_called_with('TEST_CONNECTION', connect_args={"protocol": "https"}) + + +class TableMetadataResult: + """ + Table metadata result model. + SQL result has one row per column + """ + + def __init__(self, + database: str, + schema: str, + name: str, + description: str, + column_name: str, + column_type: str, + column_comment: str, + owner: str + ) -> None: + self.database = database + self.schema = schema + self.name = name + self.description = description + self.column_name = column_name + self.column_type = column_type + self.column_comment = column_comment + self.owner = owner diff --git a/databuilder/tests/unit/extractor/test_sql_server_metadata_extractor.py b/databuilder/tests/unit/extractor/test_sql_server_metadata_extractor.py new file mode 100644 index 0000000000..10868463a1 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_sql_server_metadata_extractor.py @@ -0,0 +1,327 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest +from typing import Any, Dict + +from mock import MagicMock, patch +from pyhocon import ConfigFactory + +from databuilder.extractor.mssql_metadata_extractor import MSSQLMetadataExtractor +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + + +class TestMSSQLMetadataExtractor(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + config_dict = { + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + f'extractor.mssql_metadata.{MSSQLMetadataExtractor.CLUSTER_KEY}': 'MY_CLUSTER', + f'extractor.mssql_metadata.{MSSQLMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME}': False, + f'extractor.mssql_metadata.{MSSQLMetadataExtractor.DATABASE_KEY}': 'mssql' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_extraction_with_empty_query_result(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = MSSQLMetadataExtractor() + extractor.init(self.conf) + + results = extractor.extract() + self.assertEqual(results, None) + + def test_extraction_with_single_result(self) -> None: + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection: + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + table = {'schema_name': 'test_schema', + 'name': 'test_table', + 'description': 'a table for testing', + 'cluster': + self.conf[f'extractor.mssql_metadata.{MSSQLMetadataExtractor.CLUSTER_KEY}'] + } + + sql_execute.return_value = [ + self._union( + {'col_name': 'col_id1', + 'col_type': 'bigint', + 'col_description': 'description of id1', + 'col_sort_order': 0}, table), + self._union( + {'col_name': 'col_id2', + 'col_type': 'bigint', + 'col_description': 'description of id2', + 'col_sort_order': 1}, table), + self._union( + {'col_name': 'is_active', + 'col_type': 'boolean', + 'col_description': None, + 'col_sort_order': 2}, table), + self._union( + {'col_name': 'source', + 'col_type': 'varchar', + 'col_description': 'description of source', + 'col_sort_order': 3}, table), + self._union( + {'col_name': 'etl_created_at', + 'col_type': 'timestamp', + 'col_description': 'description of etl_created_at', + 'col_sort_order': 4}, table), + self._union( + {'col_name': 'ds', + 'col_type': 'varchar', + 'col_description': None, + 'col_sort_order': 5}, table) + ] + + extractor = MSSQLMetadataExtractor() + extractor.init(self.conf) + actual = extractor.extract() + expected = TableMetadata( + 'mssql', 'MY_CLUSTER', 'test_schema', 'test_table', 'a table for testing', + [ColumnMetadata('col_id1', 'description of id1', 'bigint', 0), + ColumnMetadata('col_id2', 'description of id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5)], + False, ['test_schema']) + + self.assertEqual(expected.__repr__(), actual.__repr__()) + self.assertIsNone(extractor.extract()) + + def test_extraction_with_multiple_result(self) -> None: + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection: + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + 
connection.execute = sql_execute + table = {'schema_name': 'test_schema1', + 'name': 'test_table1', + 'description': 'test table 1', + 'cluster': + self.conf[f'extractor.mssql_metadata.{MSSQLMetadataExtractor.CLUSTER_KEY}'] + } + + table1 = {'schema_name': 'test_schema1', + 'name': 'test_table2', + 'description': 'test table 2', + 'cluster': + self.conf[f'extractor.mssql_metadata.{MSSQLMetadataExtractor.CLUSTER_KEY}'] + } + + table2 = {'schema_name': 'test_schema2', + 'name': 'test_table3', + 'description': 'test table 3', + 'cluster': + self.conf[f'extractor.mssql_metadata.{MSSQLMetadataExtractor.CLUSTER_KEY}'] + } + + sql_execute.return_value = [ + self._union( + {'col_name': 'col_id1', + 'col_type': 'bigint', + 'col_description': 'description of col_id1', + 'col_sort_order': 0}, table), + self._union( + {'col_name': 'col_id2', + 'col_type': 'bigint', + 'col_description': 'description of col_id2', + 'col_sort_order': 1}, table), + self._union( + {'col_name': 'is_active', + 'col_type': 'boolean', + 'col_description': None, + 'col_sort_order': 2}, table), + self._union( + {'col_name': 'source', + 'col_type': 'varchar', + 'col_description': 'description of source', + 'col_sort_order': 3}, table), + self._union( + {'col_name': 'etl_created_at', + 'col_type': 'timestamp', + 'col_description': 'description of etl_created_at', + 'col_sort_order': 4}, table), + self._union( + {'col_name': 'ds', + 'col_type': 'varchar', + 'col_description': None, + 'col_sort_order': 5}, table), + self._union( + {'col_name': 'col_name', + 'col_type': 'varchar', + 'col_description': 'description of col_name', + 'col_sort_order': 0}, table1), + self._union( + {'col_name': 'col_name2', + 'col_type': 'varchar', + 'col_description': 'description of col_name2', + 'col_sort_order': 1}, table1), + self._union( + {'col_name': 'col_id3', + 'col_type': 'varchar', + 'col_description': 'description of col_id3', + 'col_sort_order': 0}, table2), + self._union( + {'col_name': 'col_name3', + 'col_type': 'varchar', + 'col_description': 'description of col_name3', + 'col_sort_order': 1}, table2) + ] + + extractor = MSSQLMetadataExtractor() + extractor.init(self.conf) + + expected = TableMetadata( + 'mssql', + self.conf[f'extractor.mssql_metadata.{MSSQLMetadataExtractor.CLUSTER_KEY}'], + 'test_schema1', 'test_table1', 'test table 1', + [ColumnMetadata('col_id1', 'description of col_id1', 'bigint', 0), + ColumnMetadata('col_id2', 'description of col_id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5), + + ], + False, ['test_schema1'] + ) + + actual = extractor.extract().__repr__() + self.assertEqual(expected.__repr__(), actual) + + expected = TableMetadata( + 'mssql', + self.conf[f'extractor.mssql_metadata.{MSSQLMetadataExtractor.CLUSTER_KEY}'], + 'test_schema1', 'test_table2', 'test table 2', + [ColumnMetadata('col_name', 'description of col_name', 'varchar', 0), + ColumnMetadata('col_name2', 'description of col_name2', 'varchar', 1)], + False, ['test_schema1']) + actual = extractor.extract().__repr__() + + self.assertEqual(expected.__repr__(), actual) + + expected = TableMetadata( + 'mssql', + self.conf[f'extractor.mssql_metadata.{MSSQLMetadataExtractor.CLUSTER_KEY}'], + 'test_schema2', 'test_table3', 'test table 3', + [ColumnMetadata('col_id3', 'description of col_id3', 'varchar', 0), + ColumnMetadata('col_name3', 
'description of col_name3', + 'varchar', 1)], + False, ['test_schema2']) + actual = extractor.extract().__repr__() + self.assertEqual(expected.__repr__(), actual) + + self.assertIsNone(extractor.extract()) + self.assertIsNone(extractor.extract()) + + def _union(self, + target: Dict[Any, Any], + extra: Dict[Any, Any]) -> Dict[Any, Any]: + target.update(extra) + return target + + +class TestMSSQLMetadataExtractorWithWhereClause(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.where_clause_suffix = """ + where table_schema in ('public') and table_name = 'movies' + """ + + config_dict = { + MSSQLMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY: self.where_clause_suffix, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = MSSQLMetadataExtractor() + extractor.init(self.conf) + self.assertTrue(self.where_clause_suffix in extractor.sql_stmt) + + +class TestMSSQLMetadataExtractorClusterKeyNoTableCatalog(unittest.TestCase): + # test when USE_CATALOG_AS_CLUSTER_NAME is false and CLUSTER_KEY is specified + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.cluster_key = "not_master" + + config_dict = { + MSSQLMetadataExtractor.CLUSTER_KEY: self.cluster_key, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': + 'TEST_CONNECTION', + MSSQLMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME: False + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = MSSQLMetadataExtractor() + extractor.init(self.conf) + self.assertTrue(self.cluster_key in extractor.sql_stmt) + + +class TestMSSQLMetadataExtractorNoClusterKeyNoTableCatalog(unittest.TestCase): + # test when USE_CATALOG_AS_CLUSTER_NAME is false and CLUSTER_KEY is NOT specified + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + config_dict = { + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + MSSQLMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME: False + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = MSSQLMetadataExtractor() + extractor.init(self.conf) + self.assertTrue(MSSQLMetadataExtractor.DEFAULT_CLUSTER_NAME in extractor.sql_stmt) + + +class TestMSSQLMetadataExtractorTableCatalogEnabled(unittest.TestCase): + # test when USE_CATALOG_AS_CLUSTER_NAME is true (CLUSTER_KEY should be ignored) + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.cluster_key = "not_master" + + config_dict = { + MSSQLMetadataExtractor.CLUSTER_KEY: self.cluster_key, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + MSSQLMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME: True + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = MSSQLMetadataExtractor() + extractor.init(self.conf) + self.assertTrue('DB_NAME()' in extractor.sql_stmt) + 
self.assertFalse(self.cluster_key in extractor.sql_stmt) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/test_teradata_metadata_extractor.py b/databuilder/tests/unit/extractor/test_teradata_metadata_extractor.py new file mode 100644 index 0000000000..3494a0f0d3 --- /dev/null +++ b/databuilder/tests/unit/extractor/test_teradata_metadata_extractor.py @@ -0,0 +1,310 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest +from typing import Any, Dict + +from mock import MagicMock, patch +from pyhocon import ConfigFactory + +from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.extractor.teradata_metadata_extractor import TeradataMetadataExtractor +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata + + +class TestTeradataMetadataExtractor(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + config_dict = { + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + TeradataMetadataExtractor.CLUSTER_KEY: 'MY_CLUSTER', + TeradataMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME: False, + TeradataMetadataExtractor.DATABASE_KEY: 'teradata' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_extraction_with_empty_query_result(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = TeradataMetadataExtractor() + extractor.init(self.conf) + + results = extractor.extract() + self.assertEqual(results, None) + + def test_extraction_with_single_result(self) -> None: + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection: + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + table = {'schema': 'test_schema', + 'name': 'test_table', + 'description': 'a table for testing', + 'td_cluster': + self.conf[TeradataMetadataExtractor.CLUSTER_KEY] + } + + sql_execute.return_value = [ + self._union( + {'col_name': 'col_id1', + 'col_type': 'bigint', + 'col_description': 'description of id1', + 'col_sort_order': 0}, table), + self._union( + {'col_name': 'col_id2', + 'col_type': 'bigint', + 'col_description': 'description of id2', + 'col_sort_order': 1}, table), + self._union( + {'col_name': 'is_active', + 'col_type': 'boolean', + 'col_description': None, + 'col_sort_order': 2}, table), + self._union( + {'col_name': 'source', + 'col_type': 'varchar', + 'col_description': 'description of source', + 'col_sort_order': 3}, table), + self._union( + {'col_name': 'etl_created_at', + 'col_type': 'timestamp', + 'col_description': 'description of etl_created_at', + 'col_sort_order': 4}, table), + self._union( + {'col_name': 'ds', + 'col_type': 'varchar', + 'col_description': None, + 'col_sort_order': 5}, table) + ] + + extractor = TeradataMetadataExtractor() + extractor.init(self.conf) + actual = extractor.extract() + expected = TableMetadata('teradata', 'MY_CLUSTER', 'test_schema', 'test_table', 'a table for testing', + [ColumnMetadata('col_id1', 'description of id1', 'bigint', 0), + ColumnMetadata('col_id2', 'description of id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 
'varchar', 5)]) + + self.assertEqual(expected.__repr__(), actual.__repr__()) + self.assertIsNone(extractor.extract()) + + def test_extraction_with_multiple_result(self) -> None: + with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection: + connection = MagicMock() + mock_connection.return_value = connection + sql_execute = MagicMock() + connection.execute = sql_execute + table = {'schema': 'test_schema1', + 'name': 'test_table1', + 'description': 'test table 1', + 'td_cluster': + self.conf[TeradataMetadataExtractor.CLUSTER_KEY] + } + + table1 = {'schema': 'test_schema1', + 'name': 'test_table2', + 'description': 'test table 2', + 'td_cluster': + self.conf[TeradataMetadataExtractor.CLUSTER_KEY] + } + + table2 = {'schema': 'test_schema2', + 'name': 'test_table3', + 'description': 'test table 3', + 'td_cluster': + self.conf[TeradataMetadataExtractor.CLUSTER_KEY] + } + + sql_execute.return_value = [ + self._union( + {'col_name': 'col_id1', + 'col_type': 'bigint', + 'col_description': 'description of col_id1', + 'col_sort_order': 0}, table), + self._union( + {'col_name': 'col_id2', + 'col_type': 'bigint', + 'col_description': 'description of col_id2', + 'col_sort_order': 1}, table), + self._union( + {'col_name': 'is_active', + 'col_type': 'boolean', + 'col_description': None, + 'col_sort_order': 2}, table), + self._union( + {'col_name': 'source', + 'col_type': 'varchar', + 'col_description': 'description of source', + 'col_sort_order': 3}, table), + self._union( + {'col_name': 'etl_created_at', + 'col_type': 'timestamp', + 'col_description': 'description of etl_created_at', + 'col_sort_order': 4}, table), + self._union( + {'col_name': 'ds', + 'col_type': 'varchar', + 'col_description': None, + 'col_sort_order': 5}, table), + self._union( + {'col_name': 'col_name', + 'col_type': 'varchar', + 'col_description': 'description of col_name', + 'col_sort_order': 0}, table1), + self._union( + {'col_name': 'col_name2', + 'col_type': 'varchar', + 'col_description': 'description of col_name2', + 'col_sort_order': 1}, table1), + self._union( + {'col_name': 'col_id3', + 'col_type': 'varchar', + 'col_description': 'description of col_id3', + 'col_sort_order': 0}, table2), + self._union( + {'col_name': 'col_name3', + 'col_type': 'varchar', + 'col_description': 'description of col_name3', + 'col_sort_order': 1}, table2) + ] + + extractor = TeradataMetadataExtractor() + extractor.init(self.conf) + + expected = TableMetadata('teradata', + self.conf[TeradataMetadataExtractor.CLUSTER_KEY], + 'test_schema1', 'test_table1', 'test table 1', + [ColumnMetadata('col_id1', 'description of col_id1', 'bigint', 0), + ColumnMetadata('col_id2', 'description of col_id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5)]) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableMetadata('teradata', + self.conf[TeradataMetadataExtractor.CLUSTER_KEY], + 'test_schema1', 'test_table2', 'test table 2', + [ColumnMetadata('col_name', 'description of col_name', 'varchar', 0), + ColumnMetadata('col_name2', 'description of col_name2', 'varchar', 1)]) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableMetadata('teradata', + self.conf[TeradataMetadataExtractor.CLUSTER_KEY], + 'test_schema2', 'test_table3', 'test table 3', + 
[ColumnMetadata('col_id3', 'description of col_id3', 'varchar', 0), + ColumnMetadata('col_name3', 'description of col_name3', + 'varchar', 1)]) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + self.assertIsNone(extractor.extract()) + self.assertIsNone(extractor.extract()) + + def _union(self, + target: Dict[Any, Any], + extra: Dict[Any, Any]) -> Dict[Any, Any]: + target.update(extra) + return target + + +class TestTeradataMetadataExtractorWithWhereClause(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.where_clause_suffix = """ + where table_schema in ('public') and table_name = 'movies' + """ + + config_dict = { + TeradataMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY: self.where_clause_suffix, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION' + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = TeradataMetadataExtractor() + extractor.init(self.conf) + self.assertTrue(self.where_clause_suffix in extractor.sql_stmt) + + +class TestTeradataMetadataExtractorClusterKeyNoTableCatalog(unittest.TestCase): + # test when USE_CATALOG_AS_CLUSTER_NAME is false and CLUSTER_KEY is specified + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.cluster_key = "not_master" + + config_dict = { + TeradataMetadataExtractor.CLUSTER_KEY: self.cluster_key, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + TeradataMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME: False + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = TeradataMetadataExtractor() + extractor.init(self.conf) + self.assertTrue(self.cluster_key in extractor.sql_stmt) + + +class TestTeradataMetadataExtractorNoClusterKeyNoTableCatalog(unittest.TestCase): + # test when USE_CATALOG_AS_CLUSTER_NAME is false and CLUSTER_KEY is NOT specified + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + config_dict = { + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + TeradataMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME: False + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = TeradataMetadataExtractor() + extractor.init(self.conf) + self.assertTrue(TeradataMetadataExtractor.DEFAULT_CLUSTER_NAME in extractor.sql_stmt) + + +class TestTeradataMetadataExtractorTableCatalogEnabled(unittest.TestCase): + # test when USE_CATALOG_AS_CLUSTER_NAME is true (CLUSTER_KEY should be ignored) + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self.cluster_key = "not_master" + + config_dict = { + TeradataMetadataExtractor.CLUSTER_KEY: self.cluster_key, + f'extractor.sqlalchemy.{SQLAlchemyExtractor.CONN_STRING}': 'TEST_CONNECTION', + TeradataMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME: True + } + self.conf = ConfigFactory.from_dict(config_dict) + + def test_sql_statement(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(SQLAlchemyExtractor, '_get_connection'): + extractor = 
TeradataMetadataExtractor() + extractor.init(self.conf) + self.assertTrue('current_database()' in extractor.sql_stmt) + self.assertFalse(self.cluster_key in extractor.sql_stmt) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/extractor/user/__init__.py b/databuilder/tests/unit/extractor/user/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/extractor/user/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/extractor/user/bamboohr/__init__.py b/databuilder/tests/unit/extractor/user/bamboohr/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/extractor/user/bamboohr/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/extractor/user/bamboohr/test_bamboohr_user_extractor.py b/databuilder/tests/unit/extractor/user/bamboohr/test_bamboohr_user_extractor.py new file mode 100644 index 0000000000..68ae02482f --- /dev/null +++ b/databuilder/tests/unit/extractor/user/bamboohr/test_bamboohr_user_extractor.py @@ -0,0 +1,45 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import io +import os +import unittest + +import responses +from pyhocon import ConfigFactory + +from databuilder.extractor.user.bamboohr.bamboohr_user_extractor import BamboohrUserExtractor +from databuilder.models.user import User + + +class TestBamboohrUserExtractor(unittest.TestCase): + @responses.activate + def test_parse_testdata(self) -> None: + bhr = BamboohrUserExtractor() + bhr.init(ConfigFactory.from_dict({'api_key': 'api_key', 'subdomain': 'amundsen'})) + + testdata_xml = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + '../../../resources/extractor/user/bamboohr/testdata.xml' + ) + + with io.open(testdata_xml) as testdata: + responses.add(responses.GET, bhr._employee_directory_uri(), body=testdata.read()) + + expected = User( + email='roald@amundsen.io', + first_name='Roald', + last_name='Amundsen', + name='Roald Amundsen', + team_name='508 Corporate Marketing', + role_name='Antarctic Explorer', + ) + + actual_users = list(bhr._get_extract_iter()) + + self.assertEqual(1, len(actual_users)) + self.assertEqual(repr(expected), repr(actual_users[0])) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/filesystem/__init__.py b/databuilder/tests/unit/filesystem/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/filesystem/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/filesystem/test_filesystem.py b/databuilder/tests/unit/filesystem/test_filesystem.py new file mode 100644 index 0000000000..491d074f6d --- /dev/null +++ b/databuilder/tests/unit/filesystem/test_filesystem.py @@ -0,0 +1,50 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest +from datetime import datetime + +from mock import MagicMock +from pyhocon import ConfigFactory +from pytz import UTC + +from databuilder.filesystem.filesystem import FileSystem +from databuilder.filesystem.metadata import FileMetadata + + +class TestFileSystem(unittest.TestCase): + + def test_is_file(self) -> None: + dask_fs = MagicMock() + dask_fs.ls = MagicMock(return_value=['/foo/bar']) + + fs = FileSystem() + conf = ConfigFactory.from_dict({FileSystem.DASK_FILE_SYSTEM: dask_fs}) + fs.init(conf=conf) + + self.assertTrue(fs.is_file('/foo/bar')) + + dask_fs.ls = MagicMock(return_value=['bar', 'baz']) + + fs = FileSystem() + conf = ConfigFactory.from_dict({FileSystem.DASK_FILE_SYSTEM: dask_fs}) + fs.init(conf=conf) + + self.assertFalse(fs.is_file('foo')) + + def test_info(self) -> None: + dask_fs = MagicMock() + dask_fs.info = MagicMock(return_value={'LastModified': datetime(2018, 8, 14, 4, 12, 3, tzinfo=UTC), + 'Size': 15093}) + fs = FileSystem() + conf = ConfigFactory.from_dict({FileSystem.DASK_FILE_SYSTEM: dask_fs}) + fs.init(conf=conf) + metadata = fs.info('/foo/bar') + + expected = FileMetadata(path='/foo/bar', last_updated=datetime(2018, 8, 14, 4, 12, 3, tzinfo=UTC), size=15093) + + self.assertEqual(metadata.__repr__(), expected.__repr__()) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/loader/__init__.py b/databuilder/tests/unit/loader/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/loader/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/loader/test_file_system_atlas_csv_loader.py b/databuilder/tests/unit/loader/test_file_system_atlas_csv_loader.py new file mode 100644 index 0000000000..706cacf72c --- /dev/null +++ b/databuilder/tests/unit/loader/test_file_system_atlas_csv_loader.py @@ -0,0 +1,94 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 +import collections +import csv +import logging +import os +import unittest +from operator import itemgetter +from os import listdir +from os.path import isfile, join +from typing import ( + Any, Callable, Dict, Iterable, +) + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.job.base_job import Job +from databuilder.loader.file_system_atlas_csv_loader import FsAtlasCSVLoader +from tests.unit.models.test_atlas_serializable import ( + Actor, City, Movie, +) + +here = os.path.dirname(__file__) + + +class TestFileSystemAtlasCSVLoader(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + def _make_conf(self, test_name: str) -> ConfigTree: + prefix = '/var/tmp/TestFileSystemAtlasCSVLoader' + + return ConfigFactory.from_dict({ + FsAtlasCSVLoader.ENTITY_DIR_PATH: f'{prefix}/{test_name}/{"entities"}', + FsAtlasCSVLoader.RELATIONSHIP_DIR_PATH: f'{prefix}/{test_name}/{"relationships"}', + FsAtlasCSVLoader.SHOULD_DELETE_CREATED_DIR: True, + }) + + def tearDown(self) -> None: + Job.closer.close() + + def test_load(self) -> None: + actors = [Actor('Tom Cruise'), Actor('Meg Ryan')] + cities = [City('San Diego'), City('Oakland')] + movie = Movie('Top Gun', actors, cities) + + loader = FsAtlasCSVLoader() + + folder = 'movies' + conf = self._make_conf(folder) + + loader.init(conf) + loader.load(movie) + loader.close() + + expected_entity_path = os.path.join(here, f'../resources/fs_atlas_csv_loader/{folder}/entities') + expected_entities = self._get_csv_rows(expected_entity_path, itemgetter('qualifiedName')) + actual_entities = self._get_csv_rows( + conf.get_string(FsAtlasCSVLoader.ENTITY_DIR_PATH), + itemgetter('qualifiedName'), + ) + self.assertEqual(expected_entities, actual_entities) + + expected_rel_path = os.path.join(here, f'../resources/fs_atlas_csv_loader/{folder}/relationships') + expected_relations = self._get_csv_rows( + expected_rel_path, itemgetter( + 'entityQualifiedName1', 'entityQualifiedName2', + ), + ) + actual_relations = self._get_csv_rows( + conf.get_string(FsAtlasCSVLoader.RELATIONSHIP_DIR_PATH), + itemgetter('entityQualifiedName1', 'entityQualifiedName2'), + ) + self.assertEqual(expected_relations, actual_relations) + + def _get_csv_rows( + self, + path: str, + sorting_key_getter: Callable, + ) -> Iterable[Dict[str, Any]]: + files = [join(path, f) for f in listdir(path) if isfile(join(path, f))] + + result = [] + for f in files: + with open(f) as f_input: + reader = csv.DictReader(f_input) + for row in reader: + result.append(collections.OrderedDict(sorted(row.items()))) + print(result) + return sorted(result, key=sorting_key_getter) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/loader/test_file_system_csv_loader.py b/databuilder/tests/unit/loader/test_file_system_csv_loader.py new file mode 100644 index 0000000000..0814abcf39 --- /dev/null +++ b/databuilder/tests/unit/loader/test_file_system_csv_loader.py @@ -0,0 +1,118 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import shutil +import tempfile +import unittest +from typing import List + +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.loader.file_system_csv_loader import FileSystemCSVLoader +from tests.unit.extractor.test_sql_alchemy_extractor import TableMetadataResult + + +class TestFileSystemCSVLoader(unittest.TestCase): + + def setUp(self) -> None: + self.temp_dir_path = tempfile.mkdtemp() + self.dest_file_name = f'{self.temp_dir_path}/test_file.csv' + self.file_mode = 'w' + config_dict = {'loader.filesystem.csv.file_path': self.dest_file_name, + 'loader.filesystem.csv.mode': self.file_mode} + self.conf = ConfigFactory.from_dict(config_dict) + + def tearDown(self) -> None: + shutil.rmtree(self.temp_dir_path) + + def _check_results_helper(self, expected: List[str]) -> None: + """ + Helper function to compare results with expected outcome + :param expected: expected result + """ + with open(self.dest_file_name, 'r') as file: + for e in expected: + actual = file.readline().rstrip('\r\n') + self.assertEqual(set(e.split(',')), set(actual.split(','))) + self.assertFalse(file.readline()) + + def test_empty_loading(self) -> None: + """ + Test loading functionality with no data + """ + loader = FileSystemCSVLoader() + loader.init(conf=Scoped.get_scoped_conf(conf=self.conf, + scope=loader.get_scope())) + + loader.load(None) + loader.close() + + self._check_results_helper(expected=[]) + + def test_loading_with_single_object(self) -> None: + """ + Test Loading functionality with single python object + """ + loader = FileSystemCSVLoader() + loader.init(conf=Scoped.get_scoped_conf(conf=self.conf, + scope=loader.get_scope())) + + data = TableMetadataResult(database='test_database', + schema='test_schema', + name='test_table', + description='test_description', + column_name='test_column_name', + column_type='test_column_type', + column_comment='test_column_comment', + owner='test_owner') + loader.load(data) + loader.close() + + expected = [ + ','.join(['database', 'schema', 'name', 'description', + 'column_name', 'column_type', 'column_comment', + 'owner']), + ','.join(['test_database', 'test_schema', 'test_table', + 'test_description', 'test_column_name', + 'test_column_type', 'test_column_comment', + 'test_owner']) + ] + + self._check_results_helper(expected=expected) + + def test_loading_with_list_of_objects(self) -> None: + """ + Test Loading functionality with list of objects. 
+ Check to ensure all objects are added to file + """ + loader = FileSystemCSVLoader() + loader.init(conf=Scoped.get_scoped_conf(conf=self.conf, + scope=loader.get_scope())) + + data = [TableMetadataResult(database='test_database', + schema='test_schema', + name='test_table', + description='test_description', + column_name='test_column_name', + column_type='test_column_type', + column_comment='test_column_comment', + owner='test_owner')] * 5 + + for d in data: + loader.load(d) + loader.close() + + expected = [ + ','.join(['database', 'schema', 'name', 'description', + 'column_name', 'column_type', 'column_comment', + 'owner']) + ] + expected = expected + [ + ','.join(['test_database', 'test_schema', 'test_table', + 'test_description', 'test_column_name', + 'test_column_type', 'test_column_comment', 'test_owner'] + ) + ] * 5 + + self._check_results_helper(expected=expected) diff --git a/databuilder/tests/unit/loader/test_file_system_elasticsearch_json_loader.py b/databuilder/tests/unit/loader/test_file_system_elasticsearch_json_loader.py new file mode 100644 index 0000000000..ef3d886083 --- /dev/null +++ b/databuilder/tests/unit/loader/test_file_system_elasticsearch_json_loader.py @@ -0,0 +1,161 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import json +import shutil +import tempfile +import unittest +from typing import List + +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.loader.file_system_elasticsearch_json_loader import FSElasticsearchJSONLoader +from databuilder.models.table_elasticsearch_document import TableESDocument + + +class TestFSElasticsearchJSONLoader(unittest.TestCase): + + def setUp(self) -> None: + self.temp_dir_path = tempfile.mkdtemp() + self.dest_file_name = f'{self.temp_dir_path}/test_file.json' + self.file_mode = 'w' + config_dict = {'loader.filesystem.elasticsearch.file_path': self.dest_file_name, + 'loader.filesystem.elasticsearch.mode': self.file_mode} + self.conf = ConfigFactory.from_dict(config_dict) + + def tearDown(self) -> None: + shutil.rmtree(self.temp_dir_path) + + def _check_results_helper(self, expected: List[str]) -> None: + """ + Helper function to compare results with expected outcome + :param expected: expected result + """ + with open(self.dest_file_name, 'r') as file: + for e in expected: + actual = file.readline().rstrip('\r\n') + self.assertDictEqual(json.loads(e), json.loads(actual)) + self.assertFalse(file.readline()) + + def test_empty_loading(self) -> None: + """ + Test loading functionality with no data + """ + loader = FSElasticsearchJSONLoader() + loader.init(conf=Scoped.get_scoped_conf(conf=self.conf, + scope=loader.get_scope())) + + loader.load(None) # type: ignore + loader.close() + + self._check_results_helper(expected=[]) + + def test_loading_with_different_object(self) -> None: + """ + Test Loading functionality with a python Dict object + """ + loader = FSElasticsearchJSONLoader() + loader.init(conf=Scoped.get_scoped_conf(conf=self.conf, + scope=loader.get_scope())) + + data = dict(database='test_database', + cluster='test_cluster', + schema='test_schema', + name='test_table', + key='test_table_key', + last_updated_timestamp=123456789, + description='test_description', + column_names=['test_col1', 'test_col2'], + column_descriptions=['test_comment1', 'test_comment2'], + total_usage=10, + unique_usage=5, + tags=['test_tag1', 'test_tag2'], + programmatic_descriptions=['test']) + + with self.assertRaises(Exception) as context: + loader.load(data) # 
type: ignore + self.assertIn("Record not of type 'ElasticsearchDocument'!", str(context.exception)) + + loader.close() + + def test_loading_with_single_object(self) -> None: + """ + Test Loading functionality with single python object + """ + loader = FSElasticsearchJSONLoader() + loader.init(conf=Scoped.get_scoped_conf(conf=self.conf, + scope=loader.get_scope())) + + data = TableESDocument(database='test_database', + cluster='test_cluster', + schema='test_schema', + name='test_table', + key='test_table_key', + last_updated_timestamp=123456789, + description='test_description', + column_names=['test_col1', 'test_col2'], + column_descriptions=['test_comment1', 'test_comment2'], + total_usage=10, + unique_usage=5, + tags=['test_tag1', 'test_tag2'], + badges=['badge1'], + schema_description='schema description', + programmatic_descriptions=['test']) + loader.load(data) + loader.close() + + expected = [ + ('{"key": "test_table_key", "column_descriptions": ["test_comment1", "test_comment2"], ' + '"schema": "test_schema", "database": "test_database", "cluster": "test_cluster", ' + '"column_names": ["test_col1", "test_col2"], "name": "test_table", ' + '"last_updated_timestamp": 123456789, "display_name": "test_schema.test_table", ' + '"description": "test_description", "unique_usage": 5, "total_usage": 10, ' + '"tags": ["test_tag1", "test_tag2"], "schema_description": "schema description", ' + '"programmatic_descriptions": ["test"], ' + '"badges": ["badge1"]}') + ] + + self._check_results_helper(expected=expected) + + def test_loading_with_list_of_objects(self) -> None: + """ + Test Loading functionality with list of objects. + Check to ensure all objects are added to file + """ + loader = FSElasticsearchJSONLoader() + loader.init(conf=Scoped.get_scoped_conf(conf=self.conf, + scope=loader.get_scope())) + + data = [TableESDocument(database='test_database', + cluster='test_cluster', + schema='test_schema', + name='test_table', + key='test_table_key', + last_updated_timestamp=123456789, + description='test_description', + column_names=['test_col1', 'test_col2'], + column_descriptions=['test_comment1', 'test_comment2'], + total_usage=10, + unique_usage=5, + tags=['test_tag1', 'test_tag2'], + badges=['badge1'], + schema_description='schema_description', + programmatic_descriptions=['test'])] * 5 + + for d in data: + loader.load(d) + loader.close() + + expected = [ + ('{"key": "test_table_key", "column_descriptions": ["test_comment1", "test_comment2"], ' + '"schema": "test_schema", "database": "test_database", "cluster": "test_cluster", ' + '"column_names": ["test_col1", "test_col2"], "name": "test_table", ' + '"last_updated_timestamp": 123456789, "display_name": "test_schema.test_table", ' + '"description": "test_description", "unique_usage": 5, "total_usage": 10, ' + '"tags": ["test_tag1", "test_tag2"], "schema_description": "schema_description", ' + '"programmatic_descriptions":["test"], ' + '"badges": ["badge1"]}') + ] * 5 + + self._check_results_helper(expected=expected) diff --git a/databuilder/tests/unit/loader/test_file_system_mysql_csv_loader.py b/databuilder/tests/unit/loader/test_file_system_mysql_csv_loader.py new file mode 100644 index 0000000000..0ef12dd19f --- /dev/null +++ b/databuilder/tests/unit/loader/test_file_system_mysql_csv_loader.py @@ -0,0 +1,70 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import os +import unittest +from collections import OrderedDict +from csv import DictReader +from os import listdir +from os.path import ( + basename, isfile, join, splitext, +) +from typing import Any, Dict + +from pyhocon import ConfigFactory + +from databuilder.job.base_job import Job +from databuilder.loader.file_system_mysql_csv_loader import FSMySQLCSVLoader +from tests.unit.models.test_table_serializable import Actor, Movie + + +class TestFileSystemMySQLCSVLoader(unittest.TestCase): + def setUp(self) -> None: + directory = '/var/tmp/TestFileSystemMySQLCSVLoader' + self._conf = ConfigFactory.from_dict( + { + FSMySQLCSVLoader.RECORD_DIR_PATH: '{}/{}'.format(directory, 'records'), + FSMySQLCSVLoader.SHOULD_DELETE_CREATED_DIR: True, + FSMySQLCSVLoader.FORCE_CREATE_DIR: True, + } + ) + + def tearDown(self) -> None: + Job.closer.close() + + def test_load(self) -> None: + actors = [Actor('Tom Cruise'), Actor('Meg Ryan')] + movie = Movie('Top Gun', actors) + + loader = FSMySQLCSVLoader() + loader.init(self._conf) + loader.load(movie) + + loader.close() + + expected_record_path = '{}/../resources/fs_mysql_csv_loader/records'.format( + os.path.join(os.path.dirname(__file__)) + ) + expected_records = self._get_csv_rows(expected_record_path) + actual_records = self._get_csv_rows(self._conf.get_string(FSMySQLCSVLoader.RECORD_DIR_PATH)) + + self.maxDiff = None + self.assertDictEqual(expected_records, actual_records) + + def _get_csv_rows(self, path: str) -> Dict[str, Any]: + files = [join(path, f) for f in listdir(path) if isfile(join(path, f))] + + result: Dict[str, Any] = {} + for f in files: + filename = splitext(basename(f))[0] + result[filename] = [] + with open(f, 'r') as f_input: + reader = DictReader(f_input) + for row in reader: + result[filename].append(OrderedDict(sorted(row.items()))) + + return result + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/loader/test_file_system_neptune_csv_loader.py b/databuilder/tests/unit/loader/test_file_system_neptune_csv_loader.py new file mode 100644 index 0000000000..afa170e5e2 --- /dev/null +++ b/databuilder/tests/unit/loader/test_file_system_neptune_csv_loader.py @@ -0,0 +1,94 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import os +import unittest +from collections import OrderedDict +from csv import DictReader +from operator import itemgetter +from os import listdir +from os.path import isfile, join +from typing import ( + Any, Callable, Dict, Iterable, +) + +from freezegun import freeze_time +from pyhocon import ConfigFactory + +from databuilder.job.base_job import Job +from databuilder.loader.file_system_neptune_csv_loader import FSNeptuneCSVLoader +from tests.unit.models.test_graph_serializable import ( + Actor, City, Movie, +) + + +class FileSystemNeptuneCSVLoaderTest(unittest.TestCase): + def setUp(self) -> None: + prefix = '/var/tmp/TestFileSystemNeptuneCSVLoader' + self._conf = ConfigFactory.from_dict( + { + FSNeptuneCSVLoader.NODE_DIR_PATH: '{}/{}'.format(prefix, 'nodes'), + FSNeptuneCSVLoader.RELATION_DIR_PATH: '{}/{}'.format(prefix, 'relationships'), + FSNeptuneCSVLoader.SHOULD_DELETE_CREATED_DIR: True, + FSNeptuneCSVLoader.FORCE_CREATE_DIR: True, + FSNeptuneCSVLoader.JOB_PUBLISHER_TAG: 'TESTED' + } + ) + + def tearDown(self) -> None: + Job.closer.close() + + @freeze_time("2020-09-01 01:01:00") + def test_load(self) -> None: + actors = [Actor('Tom Cruise'), Actor('Meg Ryan')] + cities = [City('San Diego'), City('Oakland')] + movie = Movie('Top Gun', actors, cities) + + loader = FSNeptuneCSVLoader() + loader.init(self._conf) + loader.load(movie) + + loader.close() + + expected_node_path = '{}/../resources/fs_neptune_csv_loader/nodes'.format( + os.path.join(os.path.dirname(__file__)) + ) + expected_nodes = self._get_csv_rows( + expected_node_path, + itemgetter('~id') + ) + actual_nodes = self._get_csv_rows( + self._conf.get_string(FSNeptuneCSVLoader.NODE_DIR_PATH), + itemgetter('~id') + ) + self.maxDiff = None + self.assertEqual(expected_nodes, actual_nodes) + + expected_rel_path = '{}/../resources/fs_neptune_csv_loader/relationships'.format( + os.path.join(os.path.dirname(__file__)) + ) + expected_relations = self._get_csv_rows( + expected_rel_path, + itemgetter('~id') + ) + actual_relations = self._get_csv_rows( + self._conf.get_string(FSNeptuneCSVLoader.RELATION_DIR_PATH), + itemgetter('~id') + ) + self.assertListEqual(list(expected_relations), list(actual_relations)) + + def _get_csv_rows(self, path: str, sorting_key_getter: Callable) -> Iterable[Dict[str, Any]]: + files = [join(path, f) for f in listdir(path) if isfile(join(path, f))] + + result = [] + for f in files: + with open(f, 'r') as f_input: + reader = DictReader(f_input) + for row in reader: + result.append(OrderedDict(sorted(row.items()))) + + return sorted(result, key=sorting_key_getter) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/loader/test_fs_neo4j_csv_loader.py b/databuilder/tests/unit/loader/test_fs_neo4j_csv_loader.py new file mode 100644 index 0000000000..c37f2cd720 --- /dev/null +++ b/databuilder/tests/unit/loader/test_fs_neo4j_csv_loader.py @@ -0,0 +1,151 @@ +# Copyright Contributors to the Amundsen project. 
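# --- Editorial sketch, not part of the diff. ---
# Every filesystem loader exercised in these tests follows the same
# lifecycle: init() with a pyhocon config, load() once per record, and
# close() to flush the CSV/JSON files. The config keys below are the ones
# used in the Neptune CSV test above; the output directories and the
# publisher tag value are illustrative placeholders.
from pyhocon import ConfigFactory

from databuilder.loader.file_system_neptune_csv_loader import FSNeptuneCSVLoader
from tests.unit.models.test_graph_serializable import Actor, City, Movie

records = [Movie('Top Gun',
                 [Actor('Tom Cruise'), Actor('Meg Ryan')],
                 [City('San Diego'), City('Oakland')])]

loader = FSNeptuneCSVLoader()
loader.init(ConfigFactory.from_dict({
    FSNeptuneCSVLoader.NODE_DIR_PATH: '/var/tmp/example/nodes',
    FSNeptuneCSVLoader.RELATION_DIR_PATH: '/var/tmp/example/relationships',
    FSNeptuneCSVLoader.SHOULD_DELETE_CREATED_DIR: True,
    FSNeptuneCSVLoader.FORCE_CREATE_DIR: True,
    FSNeptuneCSVLoader.JOB_PUBLISHER_TAG: 'example',
}))
for record in records:  # any GraphSerializable instances
    loader.load(record)
loader.close()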
+# SPDX-License-Identifier: Apache-2.0 + +import collections +import csv +import logging +import os +import unittest +from operator import itemgetter +from os import listdir +from os.path import isfile, join +from typing import ( + Any, Callable, Dict, Iterable, Optional, Union, +) + +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.job.base_job import Job +from databuilder.loader.file_system_neo4j_csv_loader import FsNeo4jCSVLoader +from databuilder.models.graph_serializable import ( + GraphNode, GraphRelationship, GraphSerializable, +) +from tests.unit.models.test_graph_serializable import ( + Actor, City, Movie, +) + +here = os.path.dirname(__file__) + + +class TestFsNeo4jCSVLoader(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + def tearDown(self) -> None: + Job.closer.close() + + def test_load(self) -> None: + actors = [Actor('Tom Cruise'), Actor('Meg Ryan')] + cities = [City('San Diego'), City('Oakland')] + movie = Movie('Top Gun', actors, cities) + + loader = FsNeo4jCSVLoader() + + folder = 'movies' + conf = self._make_conf(folder) + + loader.init(conf) + loader.load(movie) + loader.close() + + expected_node_path = os.path.join(here, f'../resources/fs_neo4j_csv_loader/{folder}/nodes') + expected_nodes = self._get_csv_rows(expected_node_path, itemgetter('KEY')) + actual_nodes = self._get_csv_rows(conf.get_string(FsNeo4jCSVLoader.NODE_DIR_PATH), + itemgetter('KEY')) + self.assertEqual(expected_nodes, actual_nodes) + + expected_rel_path = os.path.join(here, f'../resources/fs_neo4j_csv_loader/{folder}/relationships') + expected_relations = self._get_csv_rows(expected_rel_path, itemgetter('START_KEY', 'END_KEY')) + actual_relations = self._get_csv_rows(conf.get_string(FsNeo4jCSVLoader.RELATION_DIR_PATH), + itemgetter('START_KEY', 'END_KEY')) + self.assertEqual(expected_relations, actual_relations) + + def test_load_disjoint_properties(self) -> None: + people = [ + Person("Taylor", job="Engineer"), + Person("Griffin", pet="Lion"), + ] + + loader = FsNeo4jCSVLoader() + + folder = 'people' + conf = self._make_conf(folder) + + loader.init(conf) + loader.load(people[0]) + loader.load(people[1]) + loader.close() + + expected_node_path = os.path.join(here, f'../resources/fs_neo4j_csv_loader/{folder}/nodes') + expected_nodes = self._get_csv_rows(expected_node_path, itemgetter('KEY')) + actual_nodes = self._get_csv_rows(conf.get_string(FsNeo4jCSVLoader.NODE_DIR_PATH), + itemgetter('KEY')) + self.assertEqual(expected_nodes, actual_nodes) + + def _make_conf(self, test_name: str) -> ConfigTree: + prefix = '/var/tmp/TestFsNeo4jCSVLoader' + + return ConfigFactory.from_dict({ + FsNeo4jCSVLoader.NODE_DIR_PATH: f'{prefix}/{test_name}/{"nodes"}', + FsNeo4jCSVLoader.RELATION_DIR_PATH: f'{prefix}/{test_name}/{"relationships"}', + FsNeo4jCSVLoader.SHOULD_DELETE_CREATED_DIR: True + }) + + def _get_csv_rows(self, + path: str, + sorting_key_getter: Callable) -> Iterable[Dict[str, Any]]: + files = [join(path, f) for f in listdir(path) if isfile(join(path, f))] + + result = [] + for f in files: + with open(f, 'r') as f_input: + reader = csv.DictReader(f_input) + for row in reader: + result.append(collections.OrderedDict(sorted(row.items()))) + + return sorted(result, key=sorting_key_getter) + + +class Person(GraphSerializable): + """ A Person has multiple optional attributes. When an attribute is None, + it is not included in the resulting node. 
+ """ + LABEL = 'Person' + KEY_FORMAT = 'person://{}' + + def __init__(self, + name: str, + *, + pet: Optional[str] = None, + job: Optional[str] = None, + ) -> None: + self._name = name + self._pet = pet + self._job = job + self._node_iter = iter(self.create_nodes()) + + def create_next_node(self) -> Union[GraphNode, None]: + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Union[GraphRelationship, None]: + return None + + def create_nodes(self) -> Iterable[GraphNode]: + attributes = {"name": self._name} + if self._pet: + attributes['pet'] = self._pet + if self._job: + attributes['job'] = self._job + + return [GraphNode( + key=Person.KEY_FORMAT.format(self._name), + label=Person.LABEL, + attributes=attributes + )] + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/loader/test_generic_loader.py b/databuilder/tests/unit/loader/test_generic_loader.py new file mode 100644 index 0000000000..b44f4219f2 --- /dev/null +++ b/databuilder/tests/unit/loader/test_generic_loader.py @@ -0,0 +1,38 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from mock import MagicMock +from pyhocon import ConfigFactory + +from databuilder.loader.generic_loader import CALLBACK_FUNCTION, GenericLoader + + +class TestGenericLoader(unittest.TestCase): + + def test_loading(self) -> None: + + loader = GenericLoader() + callback_func = MagicMock() + loader.init(conf=ConfigFactory.from_dict({ + CALLBACK_FUNCTION: callback_func + })) + + loader.load({'foo': 'bar'}) + loader.close() + + callback_func.assert_called_once() + + def test_none_loading(self) -> None: + + loader = GenericLoader() + callback_func = MagicMock() + loader.init(conf=ConfigFactory.from_dict({ + CALLBACK_FUNCTION: callback_func + })) + + loader.load(None) + loader.close() + + callback_func.assert_not_called() diff --git a/databuilder/tests/unit/models/__init__.py b/databuilder/tests/unit/models/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/models/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/models/dashboard/__init__.py b/databuilder/tests/unit/models/dashboard/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/models/dashboard/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/models/dashboard/test_dashboard_chart.py b/databuilder/tests/unit/models/dashboard/test_dashboard_chart.py new file mode 100644 index 0000000000..670fe8c604 --- /dev/null +++ b/databuilder/tests/unit/models/dashboard/test_dashboard_chart.py @@ -0,0 +1,229 @@ +# Copyright Contributors to the Amundsen project. 
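# --- Editorial sketch, not part of the diff. ---
# GenericLoader (tested above) hands every non-None record to the
# configured callback, so it can be wired to any callable. The lambda
# below accepts arbitrary arguments because the exact callback signature
# is not pinned down by the test; treat that as an assumption.
from pyhocon import ConfigFactory

from databuilder.loader.generic_loader import CALLBACK_FUNCTION, GenericLoader

loader = GenericLoader()
loader.init(conf=ConfigFactory.from_dict({
    CALLBACK_FUNCTION: lambda *args, **kwargs: print('record received', *args),
}))
loader.load({'foo': 'bar'})  # callback fires once
loader.load(None)            # the callback is not invoked for None records
loader.close()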
+# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import Any, Dict +from unittest.mock import ANY + +from databuilder.models.dashboard.dashboard_chart import DashboardChart +from databuilder.models.graph_serializable import ( + RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, RELATION_START_LABEL, + RELATION_TYPE, +) +from databuilder.serializers import ( + atlas_serializer, mysql_serializer, neo4_serializer, neptune_serializer, +) +from databuilder.serializers.neptune_serializer import ( + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_HEADER_ID, NEPTUNE_HEADER_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_RELATIONSHIP_HEADER_FROM, + NEPTUNE_RELATIONSHIP_HEADER_TO, +) + + +class TestDashboardChart(unittest.TestCase): + + def test_create_nodes(self) -> None: + dashboard_chart = DashboardChart(dashboard_group_id='dg_id', + dashboard_id='d_id', + query_id='q_id', + chart_id='c_id', + chart_name='c_name', + chart_type='bar', + chart_url='http://gold.foo/chart' + ) + + actual = dashboard_chart.create_next_node() + actual_serialized = neo4_serializer.serialize_node(actual) + neptune_serialized = neptune_serializer.convert_node(actual) + expected: Dict[str, Any] = { + 'name': 'c_name', + 'type': 'bar', + 'id': 'c_id', + 'url': 'http://gold.foo/chart', + 'KEY': '_dashboard://gold.dg_id/d_id/query/q_id/chart/c_id', + 'LABEL': 'Chart' + } + neptune_expected = { + '~id': 'Chart:_dashboard://gold.dg_id/d_id/query/q_id/chart/c_id', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: '_dashboard://gold.dg_id/d_id/query/q_id/chart/c_id', + '~label': 'Chart', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'type:String(single)': 'bar', + 'name:String(single)': 'c_name', + 'id:String(single)': 'c_id', + 'url:String(single)': 'http://gold.foo/chart', + } + + assert actual is not None + self.assertDictEqual(expected, actual_serialized) + self.assertDictEqual(neptune_expected, neptune_serialized) + self.assertIsNone(dashboard_chart.create_next_node()) + + dashboard_chart = DashboardChart( + dashboard_group_id='dg_id', + dashboard_id='d_id', + query_id='q_id', + chart_id='c_id', + chart_url='http://gold.foo.bar/' + ) + + actual2 = dashboard_chart.create_next_node() + actual2_serialized = neo4_serializer.serialize_node(actual2) + actual2_neptune_serialized = neptune_serializer.convert_node(actual2) + expected2: Dict[str, Any] = { + 'id': 'c_id', + 'KEY': '_dashboard://gold.dg_id/d_id/query/q_id/chart/c_id', + 'LABEL': 'Chart', + 'url': 'http://gold.foo.bar/' + } + neptune_expected2 = { + '~id': 'Chart:_dashboard://gold.dg_id/d_id/query/q_id/chart/c_id', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: '_dashboard://gold.dg_id/d_id/query/q_id/chart/c_id', + '~label': 'Chart', + 'id:String(single)': 'c_id', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'url:String(single)': 'http://gold.foo.bar/', + } + assert actual2 is not None + self.assertDictEqual(expected2, actual2_serialized) + self.assertDictEqual(neptune_expected2, actual2_neptune_serialized) + + def test_create_relation(self) -> None: + 
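        # What follows checks that DashboardChart emits a Query -> Chart
        # relation: HAS_CHART in the forward direction and CHART_OF in
        # reverse. In the Neptune bulk-loader output, each edge row's '~id'
        # is built as "<label>:<from_vertex_id>_<to_vertex_id>", where a
        # vertex id is "<node label>:<node key>", as the expected dicts
        # below spell out.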
dashboard_chart = DashboardChart(dashboard_group_id='dg_id', + dashboard_id='d_id', + query_id='q_id', + chart_id='c_id', + chart_name='c_name', + chart_type='bar', + ) + + actual = dashboard_chart.create_next_relation() + actual_serialized = neo4_serializer.serialize_relationship(actual) + actual_neptune_serialized = neptune_serializer.convert_relationship(actual) + start_key = '_dashboard://gold.dg_id/d_id/query/q_id' + end_key = '_dashboard://gold.dg_id/d_id/query/q_id/chart/c_id' + expected: Dict[str, Any] = { + RELATION_END_KEY: end_key, + RELATION_START_LABEL: 'Query', + RELATION_END_LABEL: 'Chart', + RELATION_START_KEY: start_key, + RELATION_TYPE: 'HAS_CHART', + RELATION_REVERSE_TYPE: 'CHART_OF' + } + + neptune_forward_expected = { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="Query:" + start_key, + to_vertex_id="Chart:" + end_key, + label='HAS_CHART' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="Query:" + start_key, + to_vertex_id="Chart:" + end_key, + label='HAS_CHART' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: "Query:" + start_key, + NEPTUNE_RELATIONSHIP_HEADER_TO: "Chart:" + end_key, + NEPTUNE_HEADER_LABEL: 'HAS_CHART', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + + neptune_reversed_expected = { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="Chart:" + end_key, + to_vertex_id="Query:" + start_key, + label='CHART_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="Chart:" + end_key, + to_vertex_id="Query:" + start_key, + label='CHART_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: "Chart:" + end_key, + NEPTUNE_RELATIONSHIP_HEADER_TO: "Query:" + start_key, + NEPTUNE_HEADER_LABEL: 'CHART_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + + assert actual is not None + self.assertEqual(expected, actual_serialized) + self.assertEqual(neptune_forward_expected, actual_neptune_serialized[0]) + self.assertEqual(neptune_reversed_expected, actual_neptune_serialized[1]) + self.assertIsNone(dashboard_chart.create_next_relation()) + + def test_create_records(self) -> None: + dashboard_chart = DashboardChart(dashboard_group_id='dg_id', + dashboard_id='d_id', + query_id='q_id', + chart_id='c_id', + chart_name='c_name', + chart_type='bar', + chart_url='http://gold.foo/chart' + ) + + actual = dashboard_chart.create_next_record() + actual_serialized = mysql_serializer.serialize_record(actual) + expected = { + 'rk': '_dashboard://gold.dg_id/d_id/query/q_id/chart/c_id', + 'id': 'c_id', + 'query_rk': '_dashboard://gold.dg_id/d_id/query/q_id', + 'name': 'c_name', + 'type': 'bar', + 'url': 'http://gold.foo/chart' + } + + assert actual is not None + self.assertDictEqual(expected, actual_serialized) + self.assertIsNone(dashboard_chart.create_next_record()) + + dashboard_chart = DashboardChart(dashboard_group_id='dg_id', + dashboard_id='d_id', + query_id='q_id', + chart_id='c_id', + chart_url='http://gold.foo.bar/' + ) + + actual2 = dashboard_chart.create_next_record() + actual2_serialized = mysql_serializer.serialize_record(actual2) + expected2 = { + 'rk': '_dashboard://gold.dg_id/d_id/query/q_id/chart/c_id', + 
'id': 'c_id', + 'query_rk': '_dashboard://gold.dg_id/d_id/query/q_id', + 'url': 'http://gold.foo.bar/' + } + + assert actual2 is not None + self.assertDictEqual(expected2, actual2_serialized) + + def test_create_atlas_entity(self) -> None: + dashboard_chart = DashboardChart(dashboard_group_id='dg_id', + dashboard_id='d_id', + query_id='q_id', + chart_id='c_id', + chart_name='c_name', + chart_type='bar', + chart_url='http://gold.foo/chart', + product='superset' + ) + + actual = dashboard_chart.create_next_atlas_entity() + actual_serialized = atlas_serializer.serialize_entity(actual) + expected = { + 'typeName': 'DashboardChart', + 'operation': 'CREATE', + 'relationships': 'query#DashboardQuery#superset_dashboard://gold.dg_id/d_id/query/q_id', + 'qualifiedName': 'superset_dashboard://gold.dg_id/d_id/query/q_id/chart/c_id', + 'name': 'c_name', + 'type': 'bar', + 'url': 'http://gold.foo/chart' + } + + assert actual is not None + self.assertDictEqual(expected, actual_serialized) + self.assertIsNone(dashboard_chart.create_next_atlas_entity()) diff --git a/databuilder/tests/unit/models/dashboard/test_dashboard_last_modified.py b/databuilder/tests/unit/models/dashboard/test_dashboard_last_modified.py new file mode 100644 index 0000000000..940fa434ca --- /dev/null +++ b/databuilder/tests/unit/models/dashboard/test_dashboard_last_modified.py @@ -0,0 +1,165 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import Any, Dict +from unittest.mock import ANY + +from databuilder.models.dashboard.dashboard_last_modified import DashboardLastModifiedTimestamp +from databuilder.models.graph_serializable import ( + RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, RELATION_START_LABEL, + RELATION_TYPE, +) +from databuilder.serializers import ( + atlas_serializer, mysql_serializer, neo4_serializer, neptune_serializer, +) +from databuilder.serializers.neptune_serializer import ( + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_HEADER_ID, NEPTUNE_HEADER_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_RELATIONSHIP_HEADER_FROM, + NEPTUNE_RELATIONSHIP_HEADER_TO, +) + + +class TestDashboardLastModifiedTimestamp(unittest.TestCase): + + def setUp(self) -> None: + self.dashboard_last_modified = DashboardLastModifiedTimestamp( + last_modified_timestamp=123456789, + cluster='cluster_id', + product='product_id', + dashboard_id='dashboard_id', + dashboard_group_id='dashboard_group_id' + ) + + self.expected_ts_key = 'product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id/' \ + '_last_modified_timestamp' + self.expected_dashboard_key = 'product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id' + + def test_dashboard_timestamp_nodes(self) -> None: + + actual = self.dashboard_last_modified.create_next_node() + actual_serialized = neo4_serializer.serialize_node(actual) + + expected: Dict[str, Any] = { + 'timestamp:UNQUOTED': 123456789, + 'name': 'last_updated_timestamp', + 'KEY': self.expected_ts_key, + 'LABEL': 'Timestamp' + } + + assert actual is not None + self.assertDictEqual(actual_serialized, expected) + + self.assertIsNone(self.dashboard_last_modified.create_next_node()) + + def test_neptune_dashboard_timestamp_nodes(self) -> None: + actual = self.dashboard_last_modified.create_next_node() + 
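        # The Neptune serializer maps a graph node onto a bulk-loader row:
        # '~id' becomes "<label>:<node key>", '~label' carries the node
        # label, and every property gets a typed column header such as
        # 'timestamp:Long(single)' or 'name:String(single)', as in the
        # expected dict below.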
actual_neptune_serialized = neptune_serializer.convert_node(actual) + neptune_expected = { + NEPTUNE_HEADER_ID: 'Timestamp:' + self.expected_ts_key, + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: self.expected_ts_key, + NEPTUNE_HEADER_LABEL: 'Timestamp', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'name:String(single)': 'last_updated_timestamp', + 'timestamp:Long(single)': 123456789, + } + + self.assertDictEqual(actual_neptune_serialized, neptune_expected) + + def test_dashboard_owner_relations(self) -> None: + + actual = self.dashboard_last_modified.create_next_relation() + actual_serialized = neo4_serializer.serialize_relationship(actual) + + expected: Dict[str, Any] = { + RELATION_END_KEY: self.expected_ts_key, + RELATION_START_LABEL: 'Dashboard', + RELATION_END_LABEL: 'Timestamp', + RELATION_START_KEY: self.expected_dashboard_key, + RELATION_TYPE: 'LAST_UPDATED_AT', + RELATION_REVERSE_TYPE: 'LAST_UPDATED_TIME_OF' + } + + assert actual is not None + self.assertDictEqual(actual_serialized, expected) + self.assertIsNone(self.dashboard_last_modified.create_next_relation()) + + def test_dashboard_owner_relations_neptune(self) -> None: + actual = self.dashboard_last_modified.create_next_relation() + actual_serialized = neptune_serializer.convert_relationship(actual) + neptune_forward_expected = { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboard:' + self.expected_dashboard_key, + to_vertex_id='Timestamp:' + self.expected_ts_key, + label='LAST_UPDATED_AT' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboard:' + self.expected_dashboard_key, + to_vertex_id='Timestamp:' + self.expected_ts_key, + label='LAST_UPDATED_AT' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Dashboard:' + self.expected_dashboard_key, + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Timestamp:' + self.expected_ts_key, + NEPTUNE_HEADER_LABEL: 'LAST_UPDATED_AT', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + + neptune_reversed_expected = { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Timestamp:' + self.expected_ts_key, + to_vertex_id='Dashboard:' + self.expected_dashboard_key, + label='LAST_UPDATED_TIME_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Timestamp:' + self.expected_ts_key, + to_vertex_id='Dashboard:' + self.expected_dashboard_key, + label='LAST_UPDATED_TIME_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Timestamp:' + self.expected_ts_key, + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Dashboard:' + self.expected_dashboard_key, + NEPTUNE_HEADER_LABEL: 'LAST_UPDATED_TIME_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + self.maxDiff = None + assert actual is not None + self.assertDictEqual(actual_serialized[0], neptune_forward_expected) + self.assertDictEqual(actual_serialized[1], neptune_reversed_expected) + self.assertIsNone(self.dashboard_last_modified.create_next_relation()) + + def test_dashboard_timestamp_records(self) -> None: + + actual = self.dashboard_last_modified.create_next_record() + 
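        # The MySQL serializer flattens the same model into plain column
        # values: 'rk' is the row key and 'dashboard_rk' points back to the
        # parent dashboard row, mirroring the graph relation asserted above.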
actual_serialized = mysql_serializer.serialize_record(actual) + + expected = { + 'rk': 'product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id/_last_modified_timestamp', + 'timestamp': 123456789, + 'name': 'last_updated_timestamp', + 'dashboard_rk': 'product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id' + } + + assert actual is not None + self.assertDictEqual(actual_serialized, expected) + self.assertIsNone(self.dashboard_last_modified.create_next_record()) + + def test_dashboard_last_modified_relation_atlas(self) -> None: + + actual = self.dashboard_last_modified.create_next_atlas_entity() + actual_serialized = atlas_serializer.serialize_entity(actual) + + expected = { + "typeName": "Dashboard", + "operation": "UPDATE", + "relationships": None, + "qualifiedName": "product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id", + "lastModifiedTimestamp": 123456789 + } + + assert actual is not None + self.assertDictEqual(actual_serialized, expected) + self.assertIsNone(self.dashboard_last_modified.create_next_atlas_entity()) diff --git a/databuilder/tests/unit/models/dashboard/test_dashboard_metadata.py b/databuilder/tests/unit/models/dashboard/test_dashboard_metadata.py new file mode 100644 index 0000000000..a52bd5d589 --- /dev/null +++ b/databuilder/tests/unit/models/dashboard/test_dashboard_metadata.py @@ -0,0 +1,616 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import copy +import unittest +from typing import Dict, List +from unittest.mock import ANY + +from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata +from databuilder.serializers import ( + atlas_serializer, mysql_serializer, neo4_serializer, neptune_serializer, +) +from databuilder.serializers.neptune_serializer import ( + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_HEADER_ID, NEPTUNE_HEADER_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_RELATIONSHIP_HEADER_FROM, + NEPTUNE_RELATIONSHIP_HEADER_TO, +) + + +class TestDashboardMetadata(unittest.TestCase): + def setUp(self) -> None: + self.full_dashboard_metadata = DashboardMetadata( + 'Product - Jobs.cz', + 'Agent', + 'Agent dashboard description', + ['test_tag', 'tag2'], + dashboard_group_description='foo dashboard group description', + created_timestamp=123456789, + dashboard_group_url='https://foo.bar/dashboard_group/foo', + dashboard_url='https://foo.bar/dashboard_group/foo/dashboard/bar' + ) + + # Without tags + self.dashboard_metadata2 = DashboardMetadata('Product - Atmoskop', + 'Atmoskop', + 'Atmoskop dashboard description', + [], + ) + + # One common tag with dashboard_metadata, no description + self.dashboard_metadata3 = DashboardMetadata('Product - Jobs.cz', + 'Dohazovac', + '', + ['test_tag', 'tag3'] + ) + + # Necessary minimum -- NOT USED + self.dashboard_metadata4 = DashboardMetadata('', + 'PzR', + '', + [] + ) + + self.expected_nodes_deduped = [ + { + 'KEY': '_dashboard://gold', + 'LABEL': 'Cluster', 'name': 'gold' + }, + { + 'created_timestamp:UNQUOTED': 123456789, + 'name': 'Agent', + 'KEY': '_dashboard://gold.Product - Jobs.cz/Agent', + 'LABEL': 'Dashboard', + 'dashboard_url': 'https://foo.bar/dashboard_group/foo/dashboard/bar' + }, + {'name': 'Product - Jobs.cz', 'KEY': '_dashboard://gold.Product - Jobs.cz', 'LABEL': 'Dashboardgroup', + 
'dashboard_group_url': 'https://foo.bar/dashboard_group/foo'}, + {'KEY': '_dashboard://gold.Product - Jobs.cz/_description', 'LABEL': 'Description', + 'description': 'foo dashboard group description'}, + {'description': 'Agent dashboard description', + 'KEY': '_dashboard://gold.Product - Jobs.cz/Agent/_description', 'LABEL': 'Description'}, + {'tag_type': 'dashboard', 'KEY': 'test_tag', 'LABEL': 'Tag'}, + {'tag_type': 'dashboard', 'KEY': 'tag2', 'LABEL': 'Tag'} + ] + + self.expected_nodes = copy.deepcopy(self.expected_nodes_deduped) + + self.expected_rels_deduped = [ + {'END_KEY': '_dashboard://gold.Product - Jobs.cz', 'END_LABEL': 'Dashboardgroup', + 'REVERSE_TYPE': 'DASHBOARD_GROUP_OF', 'START_KEY': '_dashboard://gold', + 'START_LABEL': 'Cluster', 'TYPE': 'DASHBOARD_GROUP'}, + {'END_KEY': '_dashboard://gold.Product - Jobs.cz/_description', 'END_LABEL': 'Description', + 'REVERSE_TYPE': 'DESCRIPTION_OF', 'START_KEY': '_dashboard://gold.Product - Jobs.cz', + 'START_LABEL': 'Dashboardgroup', 'TYPE': 'DESCRIPTION'}, + {'END_KEY': '_dashboard://gold.Product - Jobs.cz', 'START_LABEL': 'Dashboard', + 'END_LABEL': 'Dashboardgroup', + 'START_KEY': '_dashboard://gold.Product - Jobs.cz/Agent', 'TYPE': 'DASHBOARD_OF', + 'REVERSE_TYPE': 'DASHBOARD'}, + {'END_KEY': '_dashboard://gold.Product - Jobs.cz/Agent/_description', 'START_LABEL': 'Dashboard', + 'END_LABEL': 'Description', + 'START_KEY': '_dashboard://gold.Product - Jobs.cz/Agent', 'TYPE': 'DESCRIPTION', + 'REVERSE_TYPE': 'DESCRIPTION_OF'}, + {'END_KEY': 'test_tag', 'START_LABEL': 'Dashboard', 'END_LABEL': 'Tag', + 'START_KEY': '_dashboard://gold.Product - Jobs.cz/Agent', 'TYPE': 'TAG', 'REVERSE_TYPE': 'TAG_OF'}, + {'END_KEY': 'tag2', 'START_LABEL': 'Dashboard', 'END_LABEL': 'Tag', + 'START_KEY': '_dashboard://gold.Product - Jobs.cz/Agent', 'TYPE': 'TAG', 'REVERSE_TYPE': 'TAG_OF'} + ] + + self.expected_rels = copy.deepcopy(self.expected_rels_deduped) + + self.expected_nodes_deduped2 = [ + {'KEY': '_dashboard://gold', 'LABEL': 'Cluster', 'name': 'gold'}, + {'name': 'Atmoskop', 'KEY': '_dashboard://gold.Product - Atmoskop/Atmoskop', 'LABEL': 'Dashboard'}, + {'name': 'Product - Atmoskop', 'KEY': '_dashboard://gold.Product - Atmoskop', 'LABEL': 'Dashboardgroup'}, + {'description': 'Atmoskop dashboard description', + 'KEY': '_dashboard://gold.Product - Atmoskop/Atmoskop/_description', + 'LABEL': 'Description'}, + ] + + self.expected_nodes2 = copy.deepcopy(self.expected_nodes_deduped2) + + self.expected_rels_deduped2 = [ + {'END_KEY': '_dashboard://gold.Product - Atmoskop', 'END_LABEL': 'Dashboardgroup', + 'REVERSE_TYPE': 'DASHBOARD_GROUP_OF', 'START_KEY': '_dashboard://gold', + 'START_LABEL': 'Cluster', 'TYPE': 'DASHBOARD_GROUP'}, + {'END_KEY': '_dashboard://gold.Product - Atmoskop', 'START_LABEL': 'Dashboard', + 'END_LABEL': 'Dashboardgroup', + 'START_KEY': '_dashboard://gold.Product - Atmoskop/Atmoskop', 'TYPE': 'DASHBOARD_OF', + 'REVERSE_TYPE': 'DASHBOARD'}, + {'END_KEY': '_dashboard://gold.Product - Atmoskop/Atmoskop/_description', 'START_LABEL': 'Dashboard', + 'END_LABEL': 'Description', + 'START_KEY': '_dashboard://gold.Product - Atmoskop/Atmoskop', 'TYPE': 'DESCRIPTION', + 'REVERSE_TYPE': 'DESCRIPTION_OF'}, + ] + + self.expected_rels2 = copy.deepcopy(self.expected_rels_deduped2) + + self.expected_nodes_deduped3 = [ + {'KEY': '_dashboard://gold', 'LABEL': 'Cluster', 'name': 'gold'}, + {'name': 'Dohazovac', 'KEY': '_dashboard://gold.Product - Jobs.cz/Dohazovac', 'LABEL': 'Dashboard'}, + {'name': 'Product - Jobs.cz', 'KEY': 
'_dashboard://gold.Product - Jobs.cz', 'LABEL': 'Dashboardgroup'}, + {'tag_type': 'dashboard', 'KEY': 'test_tag', 'LABEL': 'Tag'}, + {'tag_type': 'dashboard', 'KEY': 'tag3', 'LABEL': 'Tag'} + ] + + self.expected_nodes3 = copy.deepcopy(self.expected_nodes_deduped3) + + self.expected_rels_deduped3 = [ + {'END_KEY': '_dashboard://gold.Product - Jobs.cz', 'END_LABEL': 'Dashboardgroup', + 'REVERSE_TYPE': 'DASHBOARD_GROUP_OF', 'START_KEY': '_dashboard://gold', + 'START_LABEL': 'Cluster', 'TYPE': 'DASHBOARD_GROUP'}, + {'END_KEY': '_dashboard://gold.Product - Jobs.cz', 'START_LABEL': 'Dashboard', + 'END_LABEL': 'Dashboardgroup', + 'START_KEY': '_dashboard://gold.Product - Jobs.cz/Dohazovac', 'TYPE': 'DASHBOARD_OF', + 'REVERSE_TYPE': 'DASHBOARD'}, + {'END_KEY': 'test_tag', 'START_LABEL': 'Dashboard', 'END_LABEL': 'Tag', + 'START_KEY': '_dashboard://gold.Product - Jobs.cz/Dohazovac', 'TYPE': 'TAG', 'REVERSE_TYPE': 'TAG_OF'}, + {'END_KEY': 'tag3', 'START_LABEL': 'Dashboard', 'END_LABEL': 'Tag', + 'START_KEY': '_dashboard://gold.Product - Jobs.cz/Dohazovac', 'TYPE': 'TAG', 'REVERSE_TYPE': 'TAG_OF'}, + ] + + self.expected_rels3 = copy.deepcopy(self.expected_rels_deduped3) + + def test_full_example(self) -> None: + node_row = self.full_dashboard_metadata.next_node() + actual = [] + while node_row: + node_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_serialized) + node_row = self.full_dashboard_metadata.next_node() + + self.assertEqual(self.expected_nodes, actual) + + relation_row = self.full_dashboard_metadata.next_relation() + actual = [] + while relation_row: + relation_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_serialized) + relation_row = self.full_dashboard_metadata.next_relation() + + self.assertEqual(self.expected_rels, actual) + + def test_full_dashboard_example_neptune(self) -> None: + expected_neptune_rels = [ + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Cluster:_dashboard://gold', + to_vertex_id='Dashboardgroup:_dashboard://gold.Product - Jobs.cz', + label='DASHBOARD_GROUP' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Cluster:_dashboard://gold', + to_vertex_id='Dashboardgroup:_dashboard://gold.Product - Jobs.cz', + label='DASHBOARD_GROUP' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Cluster:_dashboard://gold', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Dashboardgroup:_dashboard://gold.Product - Jobs.cz', + NEPTUNE_HEADER_LABEL: 'DASHBOARD_GROUP', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboardgroup:_dashboard://gold.Product - Jobs.cz', + to_vertex_id='Cluster:_dashboard://gold', + label='DASHBOARD_GROUP_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboardgroup:_dashboard://gold.Product - Jobs.cz', + to_vertex_id='Cluster:_dashboard://gold', + label='DASHBOARD_GROUP_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Dashboardgroup:_dashboard://gold.Product - Jobs.cz', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Cluster:_dashboard://gold', + NEPTUNE_HEADER_LABEL: 'DASHBOARD_GROUP_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + 
NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboardgroup:_dashboard://gold.Product - Jobs.cz', + to_vertex_id='Description:_dashboard://gold.Product - Jobs.cz/_description', + label='DESCRIPTION' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboardgroup:_dashboard://gold.Product - Jobs.cz', + to_vertex_id='Description:_dashboard://gold.Product - Jobs.cz/_description', + label='DESCRIPTION' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Dashboardgroup:_dashboard://gold.Product - Jobs.cz', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Description:_dashboard://gold.Product - Jobs.cz/_description', + NEPTUNE_HEADER_LABEL: 'DESCRIPTION', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Description:_dashboard://gold.Product - Jobs.cz/_description', + to_vertex_id='Dashboardgroup:_dashboard://gold.Product - Jobs.cz', + label='DESCRIPTION_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Description:_dashboard://gold.Product - Jobs.cz/_description', + to_vertex_id='Dashboardgroup:_dashboard://gold.Product - Jobs.cz', + label='DESCRIPTION_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Description:_dashboard://gold.Product - Jobs.cz/_description', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Dashboardgroup:_dashboard://gold.Product - Jobs.cz', + NEPTUNE_HEADER_LABEL: 'DESCRIPTION_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + to_vertex_id='Dashboardgroup:_dashboard://gold.Product - Jobs.cz', + label='DASHBOARD_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + to_vertex_id='Dashboardgroup:_dashboard://gold.Product - Jobs.cz', + label='DASHBOARD_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Dashboardgroup:_dashboard://gold.Product - Jobs.cz', + NEPTUNE_HEADER_LABEL: 'DASHBOARD_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboardgroup:_dashboard://gold.Product - Jobs.cz', + to_vertex_id='Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + label='DASHBOARD' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboardgroup:_dashboard://gold.Product - Jobs.cz', + to_vertex_id='Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + label='DASHBOARD' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Dashboardgroup:_dashboard://gold.Product - Jobs.cz', + NEPTUNE_RELATIONSHIP_HEADER_TO: 
'Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + NEPTUNE_HEADER_LABEL: 'DASHBOARD', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + to_vertex_id='Description:_dashboard://gold.Product - Jobs.cz/Agent/_description', + label='DESCRIPTION' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + to_vertex_id='Description:_dashboard://gold.Product - Jobs.cz/Agent/_description', + label='DESCRIPTION' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + NEPTUNE_RELATIONSHIP_HEADER_TO: + 'Description:_dashboard://gold.Product - Jobs.cz/Agent/_description', + NEPTUNE_HEADER_LABEL: 'DESCRIPTION', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Description:_dashboard://gold.Product - Jobs.cz/Agent/_description', + to_vertex_id='Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + label='DESCRIPTION_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Description:_dashboard://gold.Product - Jobs.cz/Agent/_description', + to_vertex_id='Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + label='DESCRIPTION_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: + 'Description:_dashboard://gold.Product - Jobs.cz/Agent/_description', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + NEPTUNE_HEADER_LABEL: 'DESCRIPTION_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + to_vertex_id='Tag:test_tag', + label='TAG' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + to_vertex_id='Tag:test_tag', + label='TAG' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Tag:test_tag', + NEPTUNE_HEADER_LABEL: 'TAG', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Tag:test_tag', + to_vertex_id='Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + label='TAG_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Tag:test_tag', + to_vertex_id='Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + label='TAG_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Tag:test_tag', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Dashboard:_dashboard://gold.Product - 
Jobs.cz/Agent', + NEPTUNE_HEADER_LABEL: 'TAG_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + to_vertex_id='Tag:tag2', + label='TAG' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + to_vertex_id='Tag:tag2', + label='TAG' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Tag:tag2', + NEPTUNE_HEADER_LABEL: 'TAG', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Tag:tag2', + to_vertex_id='Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + label='TAG_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Tag:tag2', + to_vertex_id='Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + label='TAG_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Tag:tag2', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + NEPTUNE_HEADER_LABEL: 'TAG_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + ] + + expected_neptune_nodes = [ + { + NEPTUNE_HEADER_ID: 'Cluster:_dashboard://gold', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: '_dashboard://gold', + NEPTUNE_HEADER_LABEL: 'Cluster', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'name:String(single)': 'gold' + }, + { + NEPTUNE_HEADER_ID: 'Dashboard:_dashboard://gold.Product - Jobs.cz/Agent', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: '_dashboard://gold.Product - Jobs.cz/Agent', + NEPTUNE_HEADER_LABEL: 'Dashboard', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'name:String(single)': 'Agent', + 'dashboard_url:String(single)': 'https://foo.bar/dashboard_group/foo/dashboard/bar', + 'created_timestamp:Long(single)': 123456789, + }, + { + NEPTUNE_HEADER_ID: 'Dashboardgroup:_dashboard://gold.Product - Jobs.cz', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: '_dashboard://gold.Product - Jobs.cz', + NEPTUNE_HEADER_LABEL: 'Dashboardgroup', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'name:String(single)': 'Product - Jobs.cz', + 'dashboard_group_url:String(single)': 'https://foo.bar/dashboard_group/foo' + }, + { + NEPTUNE_HEADER_ID: 'Description:_dashboard://gold.Product - Jobs.cz/_description', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: '_dashboard://gold.Product - Jobs.cz/_description', + NEPTUNE_HEADER_LABEL: 'Description', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + 
NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'description:String(single)': 'foo dashboard group description', + }, + { + NEPTUNE_HEADER_ID: 'Description:_dashboard://gold.Product - Jobs.cz/Agent/_description', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: '_dashboard://gold.Product - Jobs.cz/Agent/_description', + NEPTUNE_HEADER_LABEL: 'Description', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'description:String(single)': 'Agent dashboard description' + }, + { + NEPTUNE_HEADER_ID: 'Tag:test_tag', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'test_tag', + NEPTUNE_HEADER_LABEL: 'Tag', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'tag_type:String(single)': 'dashboard' + }, + { + NEPTUNE_HEADER_ID: 'Tag:tag2', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'tag2', + NEPTUNE_HEADER_LABEL: 'Tag', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'tag_type:String(single)': 'dashboard' + }, + ] + self.maxDiff = None + node_row = self.full_dashboard_metadata.next_node() + actual = [] + while node_row: + node_serialized = neptune_serializer.convert_node(node_row) + actual.append(node_serialized) + node_row = self.full_dashboard_metadata.next_node() + + self.assertEqual(expected_neptune_nodes, actual) + + relation_row = self.full_dashboard_metadata.next_relation() + neptune_actual: List[List[Dict]] = [] + while relation_row: + relation_serialized = neptune_serializer.convert_relationship(relation_row) + neptune_actual.append(relation_serialized) + relation_row = self.full_dashboard_metadata.next_relation() + + self.assertEqual(expected_neptune_rels, neptune_actual) + + def test_dashboard_without_tags(self) -> None: + node_row = self.dashboard_metadata2.next_node() + actual = [] + while node_row: + node_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_serialized) + node_row = self.dashboard_metadata2.next_node() + + self.assertEqual(self.expected_nodes_deduped2, actual) + + relation_row = self.dashboard_metadata2.next_relation() + actual = [] + while relation_row: + relation_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_serialized) + relation_row = self.dashboard_metadata2.next_relation() + + self.assertEqual(self.expected_rels_deduped2, actual) + + def test_dashboard_no_description(self) -> None: + node_row = self.dashboard_metadata3.next_node() + actual = [] + while node_row: + node_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_serialized) + node_row = self.dashboard_metadata3.next_node() + + self.assertEqual(self.expected_nodes_deduped3, actual) + + relation_row = self.dashboard_metadata3.next_relation() + actual = [] + while relation_row: + relation_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_serialized) + relation_row = self.dashboard_metadata3.next_relation() + + self.assertEqual(self.expected_rels_deduped3, actual) + + def test_dashboard_record_full_example(self) -> None: + expected_records = [ + {'rk': '_dashboard://gold', 'name': 'gold'}, + {'rk': '_dashboard://gold.Product - Jobs.cz', 'name': 'Product - Jobs.cz', + 
'cluster_rk': '_dashboard://gold', + 'dashboard_group_url': 'https://foo.bar/dashboard_group/foo'}, + {'rk': '_dashboard://gold.Product - Jobs.cz/_description', 'description': 'foo dashboard group description', + 'dashboard_group_rk': '_dashboard://gold.Product - Jobs.cz'}, + {'rk': '_dashboard://gold.Product - Jobs.cz/Agent', 'name': 'Agent', + 'dashboard_group_rk': '_dashboard://gold.Product - Jobs.cz', 'created_timestamp': 123456789, + 'dashboard_url': 'https://foo.bar/dashboard_group/foo/dashboard/bar'}, + {'rk': '_dashboard://gold.Product - Jobs.cz/Agent/_description', + 'description': 'Agent dashboard description', + 'dashboard_rk': '_dashboard://gold.Product - Jobs.cz/Agent'}, + {'rk': 'test_tag', 'tag_type': 'dashboard'}, + {'dashboard_rk': '_dashboard://gold.Product - Jobs.cz/Agent', 'tag_rk': 'test_tag'}, + {'rk': 'tag2', 'tag_type': 'dashboard'}, + {'dashboard_rk': '_dashboard://gold.Product - Jobs.cz/Agent', 'tag_rk': 'tag2'} + ] + record = self.full_dashboard_metadata.next_record() + actual = [] + while record: + record_serialized = mysql_serializer.serialize_record(record) + actual.append(record_serialized) + record = self.full_dashboard_metadata.next_record() + + self.assertEqual(expected_records, actual) + + def test_dashboard_record_without_tags(self) -> None: + expected_records_without_tags = [ + {'rk': '_dashboard://gold', 'name': 'gold'}, + {'rk': '_dashboard://gold.Product - Atmoskop', 'name': 'Product - Atmoskop', + 'cluster_rk': '_dashboard://gold'}, + {'rk': '_dashboard://gold.Product - Atmoskop/Atmoskop', 'name': 'Atmoskop', + 'dashboard_group_rk': '_dashboard://gold.Product - Atmoskop'}, + {'rk': '_dashboard://gold.Product - Atmoskop/Atmoskop/_description', + 'description': 'Atmoskop dashboard description', + 'dashboard_rk': '_dashboard://gold.Product - Atmoskop/Atmoskop'} + ] + record = self.dashboard_metadata2.next_record() + actual = [] + while record: + record_serialized = mysql_serializer.serialize_record(record) + actual.append(record_serialized) + record = self.dashboard_metadata2.next_record() + + self.assertEqual(expected_records_without_tags, actual) + + def test_dashboard_record_no_description(self) -> None: + expected_records_without_description = [ + {'rk': '_dashboard://gold', 'name': 'gold'}, + {'rk': '_dashboard://gold.Product - Jobs.cz', 'name': 'Product - Jobs.cz', + 'cluster_rk': '_dashboard://gold'}, + {'rk': '_dashboard://gold.Product - Jobs.cz/Dohazovac', 'name': 'Dohazovac', + 'dashboard_group_rk': '_dashboard://gold.Product - Jobs.cz'}, + {'rk': 'test_tag', 'tag_type': 'dashboard'}, + {'dashboard_rk': '_dashboard://gold.Product - Jobs.cz/Dohazovac', 'tag_rk': 'test_tag'}, + {'rk': 'tag3', 'tag_type': 'dashboard'}, + {'dashboard_rk': '_dashboard://gold.Product - Jobs.cz/Dohazovac', 'tag_rk': 'tag3'} + ] + record = self.dashboard_metadata3.next_record() + actual = [] + while record: + record_serialized = mysql_serializer.serialize_record(record) + actual.append(record_serialized) + record = self.dashboard_metadata3.next_record() + + self.assertEqual(expected_records_without_description, actual) + + def test_full_dashboard_example_atlas(self) -> None: + + expected = [{'description': 'foo dashboard group description', + 'id': 'Product - Jobs.cz', + 'name': 'Product - Jobs.cz', + 'operation': 'CREATE', + 'qualifiedName': '_dashboard://gold.Product - Jobs.cz', + 'relationships': None, + 'typeName': 'DashboardGroup', + 'url': 'https://foo.bar/dashboard_group/foo'}, + {'cluster': 'gold', + 'createdTimestamp': 123456789, + 'description': 'Agent 
dashboard description', + 'name': 'Agent', + 'operation': 'CREATE', + 'product': '', + 'qualifiedName': '_dashboard://gold.Product - Jobs.cz/Agent', + 'relationships': 'group#DashboardGroup#_dashboard://gold.Product - Jobs.cz', + 'typeName': 'Dashboard', + 'url': 'https://foo.bar/dashboard_group/foo/dashboard/bar'}] + + entity = self.full_dashboard_metadata.next_atlas_entity() + actual = [] + while entity: + record_serialized = atlas_serializer.serialize_entity(entity) + actual.append(record_serialized) + entity = self.full_dashboard_metadata.next_atlas_entity() + + self.assertEqual(expected, actual) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/models/dashboard/test_dashboard_owner.py b/databuilder/tests/unit/models/dashboard/test_dashboard_owner.py new file mode 100644 index 0000000000..b7cc27c0e8 --- /dev/null +++ b/databuilder/tests/unit/models/dashboard/test_dashboard_owner.py @@ -0,0 +1,109 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from unittest.mock import ANY + +from databuilder.models.dashboard.dashboard_owner import DashboardOwner +from databuilder.models.graph_serializable import ( + RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, RELATION_START_LABEL, + RELATION_TYPE, +) +from databuilder.serializers import ( + mysql_serializer, neo4_serializer, neptune_serializer, +) +from databuilder.serializers.neptune_serializer import ( + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_HEADER_ID, NEPTUNE_HEADER_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_RELATIONSHIP_HEADER_FROM, + NEPTUNE_RELATIONSHIP_HEADER_TO, +) + + +class TestDashboardOwner(unittest.TestCase): + + def setUp(self) -> None: + self.dashboard_owner = DashboardOwner( + email='foo@bar.com', + cluster='cluster_id', + product='product_id', + dashboard_id='dashboard_id', + dashboard_group_id='dashboard_group_id' + ) + + def test_dashboard_owner_nodes(self) -> None: + actual = self.dashboard_owner.create_next_node() + self.assertIsNone(actual) + + def test_dashboard_owner_relations(self) -> None: + + actual = self.dashboard_owner.create_next_relation() + actual_serialized = neo4_serializer.serialize_relationship(actual) + expected = { + RELATION_END_KEY: 'foo@bar.com', + RELATION_START_LABEL: 'Dashboard', + RELATION_END_LABEL: 'User', + RELATION_START_KEY: 'product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + RELATION_TYPE: 'OWNER', + RELATION_REVERSE_TYPE: 'OWNER_OF' + } + assert actual is not None + self.assertDictEqual(actual_serialized, expected) + + def test_dashboard_owner_relations_neptune(self) -> None: + actual = self.dashboard_owner.create_next_relation() + actual_serialized = neptune_serializer.convert_relationship(actual) + neptune_forward_expected = { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboard:product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + to_vertex_id='User:foo@bar.com', + label='OWNER' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboard:product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + to_vertex_id='User:foo@bar.com', + label='OWNER' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: + 
'Dashboard:product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'User:foo@bar.com', + NEPTUNE_HEADER_LABEL: 'OWNER', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + + neptune_reversed_expected = { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='User:foo@bar.com', + to_vertex_id='Dashboard:product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + label='OWNER_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='User:foo@bar.com', + to_vertex_id='Dashboard:product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + label='OWNER_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'User:foo@bar.com', + NEPTUNE_RELATIONSHIP_HEADER_TO: + 'Dashboard:product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + NEPTUNE_HEADER_LABEL: 'OWNER_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + + assert actual is not None + self.assertDictEqual(actual_serialized[0], neptune_forward_expected) + self.assertDictEqual(actual_serialized[1], neptune_reversed_expected) + + def test_dashboard_owner_record(self) -> None: + + actual = self.dashboard_owner.create_next_record() + actual_serialized = mysql_serializer.serialize_record(actual) + expected = { + 'user_rk': 'foo@bar.com', + 'dashboard_rk': 'product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id' + } + + assert actual is not None + self.assertDictEqual(expected, actual_serialized) + self.assertIsNone(self.dashboard_owner.create_next_record()) diff --git a/databuilder/tests/unit/models/dashboard/test_dashboard_query.py b/databuilder/tests/unit/models/dashboard/test_dashboard_query.py new file mode 100644 index 0000000000..67c5189164 --- /dev/null +++ b/databuilder/tests/unit/models/dashboard/test_dashboard_query.py @@ -0,0 +1,156 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest +from unittest.mock import ANY + +from databuilder.models.dashboard.dashboard_query import DashboardQuery +from databuilder.models.graph_serializable import ( + NODE_KEY, NODE_LABEL, RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, + RELATION_START_LABEL, RELATION_TYPE, +) +from databuilder.serializers import ( + atlas_serializer, mysql_serializer, neo4_serializer, neptune_serializer, +) +from databuilder.serializers.neptune_serializer import ( + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_HEADER_ID, NEPTUNE_HEADER_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_RELATIONSHIP_HEADER_FROM, + NEPTUNE_RELATIONSHIP_HEADER_TO, +) + + +class TestDashboardQuery(unittest.TestCase): + + def setUp(self) -> None: + self.dashboard_query = DashboardQuery( + dashboard_group_id='dg_id', + dashboard_id='d_id', + query_id='q_id', + query_name='q_name', + url='http://foo.bar/query/baz', + query_text='SELECT * FROM foo.bar' + ) + + def test_create_nodes(self) -> None: + actual = self.dashboard_query.create_next_node() + actual_serialized = neo4_serializer.serialize_node(actual) + expected = { + 'url': 'http://foo.bar/query/baz', + 'name': 'q_name', + 'id': 'q_id', + 'query_text': 'SELECT * FROM foo.bar', + NODE_KEY: '_dashboard://gold.dg_id/d_id/query/q_id', + NODE_LABEL: DashboardQuery.DASHBOARD_QUERY_LABEL + } + + self.assertEqual(expected, actual_serialized) + + def test_create_nodes_neptune(self) -> None: + actual = self.dashboard_query.create_next_node() + actual_serialized = neptune_serializer.convert_node(actual) + neptune_expected = { + NEPTUNE_HEADER_ID: 'Query:_dashboard://gold.dg_id/d_id/query/q_id', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: '_dashboard://gold.dg_id/d_id/query/q_id', + NEPTUNE_HEADER_LABEL: DashboardQuery.DASHBOARD_QUERY_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'id:String(single)': 'q_id', + 'query_text:String(single)': 'SELECT * FROM foo.bar', + 'name:String(single)': 'q_name', + 'url:String(single)': 'http://foo.bar/query/baz' + } + self.assertEqual(neptune_expected, actual_serialized) + + def test_create_relation(self) -> None: + actual = self.dashboard_query.create_next_relation() + actual_serialized = neo4_serializer.serialize_relationship(actual) + expected = { + RELATION_END_KEY: '_dashboard://gold.dg_id/d_id/query/q_id', + RELATION_START_LABEL: 'Dashboard', + RELATION_END_LABEL: DashboardQuery.DASHBOARD_QUERY_LABEL, + RELATION_START_KEY: '_dashboard://gold.dg_id/d_id', + RELATION_TYPE: 'HAS_QUERY', + RELATION_REVERSE_TYPE: 'QUERY_OF' + } + + self.assertEqual(expected, actual_serialized) + + def test_create_relation_neptune(self) -> None: + actual = self.dashboard_query.create_next_relation() + actual_serialized = neptune_serializer.convert_relationship(actual) + neptune_forward_expected = { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboard:_dashboard://gold.dg_id/d_id', + to_vertex_id='Query:_dashboard://gold.dg_id/d_id/query/q_id', + label='HAS_QUERY' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + 
from_vertex_id='Dashboard:_dashboard://gold.dg_id/d_id', + to_vertex_id='Query:_dashboard://gold.dg_id/d_id/query/q_id', + label='HAS_QUERY' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Dashboard:_dashboard://gold.dg_id/d_id', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Query:_dashboard://gold.dg_id/d_id/query/q_id', + NEPTUNE_HEADER_LABEL: 'HAS_QUERY', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + + neptune_reversed_expected = { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Query:_dashboard://gold.dg_id/d_id/query/q_id', + to_vertex_id='Dashboard:_dashboard://gold.dg_id/d_id', + label='QUERY_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Query:_dashboard://gold.dg_id/d_id/query/q_id', + to_vertex_id='Dashboard:_dashboard://gold.dg_id/d_id', + label='QUERY_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Query:_dashboard://gold.dg_id/d_id/query/q_id', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Dashboard:_dashboard://gold.dg_id/d_id', + NEPTUNE_HEADER_LABEL: 'QUERY_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + + assert actual is not None + self.assertDictEqual(actual_serialized[0], neptune_forward_expected) + self.assertDictEqual(actual_serialized[1], neptune_reversed_expected) + + def test_create_records(self) -> None: + actual = self.dashboard_query.create_next_record() + actual_serialized = mysql_serializer.serialize_record(actual) + expected = { + 'rk': '_dashboard://gold.dg_id/d_id/query/q_id', + 'name': 'q_name', + 'id': 'q_id', + 'dashboard_rk': '_dashboard://gold.dg_id/d_id', + 'url': 'http://foo.bar/query/baz', + 'query_text': 'SELECT * FROM foo.bar' + } + + assert actual is not None + self.assertDictEqual(expected, actual_serialized) + self.assertIsNone(self.dashboard_query.create_next_record()) + + def test_create_next_atlas_entity(self) -> None: + actual = self.dashboard_query.create_next_atlas_entity() + actual_serialized = atlas_serializer.serialize_entity(actual) + + expected = { + "typeName": "DashboardQuery", + "operation": "CREATE", + "relationships": "dashboard#Dashboard#_dashboard://gold.dg_id/d_id", + "qualifiedName": "_dashboard://gold.dg_id/d_id/query/q_id", + "name": "q_name", + "id": "q_id", + "url": "http://foo.bar/query/baz", + "queryText": "SELECT * FROM foo.bar" + } + + assert actual is not None + self.assertDictEqual(expected, actual_serialized) + self.assertIsNone(self.dashboard_query.create_next_atlas_entity()) diff --git a/databuilder/tests/unit/models/dashboard/test_dashboard_table.py b/databuilder/tests/unit/models/dashboard/test_dashboard_table.py new file mode 100644 index 0000000000..db0c0c1703 --- /dev/null +++ b/databuilder/tests/unit/models/dashboard/test_dashboard_table.py @@ -0,0 +1,173 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest +from unittest.mock import ANY + +from databuilder.models.dashboard.dashboard_table import DashboardTable +from databuilder.models.graph_serializable import ( + RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, RELATION_START_LABEL, + RELATION_TYPE, +) +from databuilder.serializers import ( + atlas_serializer, mysql_serializer, neo4_serializer, neptune_serializer, +) +from databuilder.serializers.neptune_serializer import ( + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_HEADER_ID, NEPTUNE_HEADER_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_RELATIONSHIP_HEADER_FROM, + NEPTUNE_RELATIONSHIP_HEADER_TO, +) + + +class TestDashboardTable(unittest.TestCase): + def test_dashboard_table_nodes(self) -> None: + dashboard_table = DashboardTable(table_ids=['hive://gold.schema/table1', 'hive://gold.schema/table2'], + cluster='cluster_id', product='product_id', + dashboard_id='dashboard_id', dashboard_group_id='dashboard_group_id') + + actual = dashboard_table.create_next_node() + self.assertIsNone(actual) + + def test_dashboard_table_relations(self) -> None: + dashboard_table = DashboardTable(table_ids=['hive://gold.schema/table1'], + cluster='cluster_id', product='product_id', + dashboard_id='dashboard_id', dashboard_group_id='dashboard_group_id') + + actual = dashboard_table.create_next_relation() + actual_serialized = neo4_serializer.serialize_relationship(actual) + actual_neptune_serialized = neptune_serializer.convert_relationship(actual) + expected = {RELATION_END_KEY: 'hive://gold.schema/table1', RELATION_START_LABEL: 'Dashboard', + RELATION_END_LABEL: 'Table', + RELATION_START_KEY: 'product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + RELATION_TYPE: 'DASHBOARD_WITH_TABLE', + RELATION_REVERSE_TYPE: 'TABLE_OF_DASHBOARD'} + + neptune_forward_expected = { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboard:product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + to_vertex_id='Table:hive://gold.schema/table1', + label='DASHBOARD_WITH_TABLE' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboard:product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + to_vertex_id='Table:hive://gold.schema/table1', + label='DASHBOARD_WITH_TABLE' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: + 'Dashboard:product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Table:hive://gold.schema/table1', + NEPTUNE_HEADER_LABEL: 'DASHBOARD_WITH_TABLE', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + + neptune_reversed_expected = { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.schema/table1', + to_vertex_id='Dashboard:product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + label='TABLE_OF_DASHBOARD' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.schema/table1', + to_vertex_id='Dashboard:product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + label='TABLE_OF_DASHBOARD' + ), + 
NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Table:hive://gold.schema/table1', + NEPTUNE_RELATIONSHIP_HEADER_TO: + 'Dashboard:product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + NEPTUNE_HEADER_LABEL: 'TABLE_OF_DASHBOARD', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + assert actual is not None + self.assertDictEqual(actual_serialized, expected) + self.assertDictEqual(actual_neptune_serialized[0], neptune_forward_expected) + self.assertDictEqual(actual_neptune_serialized[1], neptune_reversed_expected) + + def test_dashboard_table_without_dot_as_name(self) -> None: + dashboard_table = DashboardTable(table_ids=['bq-name://project-id.schema-name/table-name'], + cluster='cluster_id', product='product_id', + dashboard_id='dashboard_id', dashboard_group_id='dashboard_group_id') + actual = dashboard_table.create_next_relation() + actual_serialized = neo4_serializer.serialize_relationship(actual) + expected = {RELATION_END_KEY: 'bq-name://project-id.schema-name/table-name', RELATION_START_LABEL: 'Dashboard', + RELATION_END_LABEL: 'Table', + RELATION_START_KEY: 'product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + RELATION_TYPE: 'DASHBOARD_WITH_TABLE', + RELATION_REVERSE_TYPE: 'TABLE_OF_DASHBOARD'} + assert actual is not None + self.assertDictEqual(actual_serialized, expected) + + def test_dashboard_table_with_dot_as_name(self) -> None: + dashboard_table = DashboardTable(table_ids=['bq-name://project.id.schema-name/table-name'], + cluster='cluster_id', product='product_id', + dashboard_id='dashboard_id', dashboard_group_id='dashboard_group_id') + actual = dashboard_table.create_next_relation() + self.assertIsNone(actual) + + def test_dashboard_table_with_slash_as_name(self) -> None: + dashboard_table = DashboardTable(table_ids=['bq/name://project/id.schema/name/table/name'], + cluster='cluster_id', product='product_id', + dashboard_id='dashboard_id', dashboard_group_id='dashboard_group_id') + actual = dashboard_table.create_next_relation() + self.assertIsNone(actual) + + def test_dashboard_table_records(self) -> None: + dashboard_table = DashboardTable(table_ids=['hive://gold.schema/table1', 'hive://gold.schema/table2'], + cluster='cluster_id', product='product_id', + dashboard_id='dashboard_id', dashboard_group_id='dashboard_group_id') + actual1 = dashboard_table.create_next_record() + actual1_serialized = mysql_serializer.serialize_record(actual1) + expected1 = { + 'dashboard_rk': 'product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + 'table_rk': 'hive://gold.schema/table1' + } + + actual2 = dashboard_table.create_next_record() + actual2_serialized = mysql_serializer.serialize_record(actual2) + expected2 = { + 'dashboard_rk': 'product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + 'table_rk': 'hive://gold.schema/table2' + } + + assert actual1 is not None + self.assertDictEqual(expected1, actual1_serialized) + + assert actual2 is not None + self.assertDictEqual(expected2, actual2_serialized) + self.assertIsNone(dashboard_table.create_next_record()) + + def test_create_next_atlas_relation(self) -> None: + dashboard_table = DashboardTable( + table_ids=['hive://gold.schema/table1', 'hive_table://gold.schema/table2'], + cluster='cluster_id', product='product_id', + dashboard_id='dashboard_id', dashboard_group_id='dashboard_group_id', + ) + + # 'hive' is db name compatible with Amundsen Databuilder sourced data. 
in such case qn = amundsen key + # 'hive_table' is db name compatible with data sources from Atlas Hive Hook. in such case custom qn is used + expected = [ + { + "relationshipType": "Table__Dashboard", + "entityType1": "Table", + "entityQualifiedName1": "hive://gold.schema/table1", + "entityType2": "Dashboard", + "entityQualifiedName2": "product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id", + }, + { + "relationshipType": "Table__Dashboard", + "entityType1": "Table", + "entityQualifiedName1": "schema.table2@gold", + "entityType2": "Dashboard", + "entityQualifiedName2": "product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id", + }, + ] + relationship = dashboard_table.create_next_atlas_relation() # type: ignore + actual = [] + while relationship: + actual_serialized = atlas_serializer.serialize_relationship(relationship) + actual.append(actual_serialized) + relationship = dashboard_table.create_next_atlas_relation() + + self.assertEqual(expected, actual) diff --git a/databuilder/tests/unit/models/dashboard/test_dashboard_usage.py b/databuilder/tests/unit/models/dashboard/test_dashboard_usage.py new file mode 100644 index 0000000000..cc77a4dd5b --- /dev/null +++ b/databuilder/tests/unit/models/dashboard/test_dashboard_usage.py @@ -0,0 +1,145 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import Any, Dict +from unittest.mock import ANY + +from databuilder.models.dashboard.dashboard_usage import DashboardUsage +from databuilder.models.graph_serializable import ( + RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, RELATION_START_LABEL, + RELATION_TYPE, +) +from databuilder.serializers import ( + mysql_serializer, neo4_serializer, neptune_serializer, +) +from databuilder.serializers.neptune_serializer import ( + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_HEADER_ID, NEPTUNE_HEADER_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_RELATIONSHIP_HEADER_FROM, + NEPTUNE_RELATIONSHIP_HEADER_TO, +) + + +class TestDashboardOwner(unittest.TestCase): + + def test_dashboard_usage_user_nodes(self) -> None: + dashboard_usage = DashboardUsage(dashboard_group_id='dashboard_group_id', dashboard_id='dashboard_id', + email='foo@bar.com', view_count=123, cluster='cluster_id', + product='product_id', should_create_user_node=True) + + actual = dashboard_usage.create_next_node() + actual_serialized = neo4_serializer.serialize_node(actual) + expected: Dict[str, Any] = { + 'LABEL': 'User', + 'KEY': 'foo@bar.com', + 'email': 'foo@bar.com', + } + + assert actual is not None + self.assertDictEqual(expected, actual_serialized) + self.assertIsNone(dashboard_usage.create_next_node()) + + def test_dashboard_usage_no_user_nodes(self) -> None: + dashboard_usage = DashboardUsage(dashboard_group_id='dashboard_group_id', dashboard_id='dashboard_id', + email='foo@bar.com', view_count=123, + should_create_user_node=False, cluster='cluster_id', + product='product_id') + + self.assertIsNone(dashboard_usage.create_next_node()) + + def test_dashboard_owner_relations(self) -> None: + dashboard_usage = DashboardUsage(dashboard_group_id='dashboard_group_id', dashboard_id='dashboard_id', + email='foo@bar.com', view_count=123, cluster='cluster_id', + product='product_id') + + actual = dashboard_usage.create_next_relation() + actual_serialized = 
neo4_serializer.serialize_relationship(actual) + expected: Dict[str, Any] = { + 'read_count:UNQUOTED': 123, + RELATION_END_KEY: 'foo@bar.com', + RELATION_START_LABEL: 'Dashboard', + RELATION_END_LABEL: 'User', + RELATION_START_KEY: 'product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + RELATION_TYPE: 'READ_BY', + RELATION_REVERSE_TYPE: 'READ' + } + + assert actual is not None + self.assertDictEqual(expected, actual_serialized) + self.assertIsNone(dashboard_usage.create_next_relation()) + + def test_dashboard_owner_relations_neptune(self) -> None: + dashboard_usage = DashboardUsage(dashboard_group_id='dashboard_group_id', dashboard_id='dashboard_id', + email='foo@bar.com', view_count=123, cluster='cluster_id', + product='product_id') + + actual = dashboard_usage.create_next_relation() + actual_serialized = neptune_serializer.convert_relationship(actual) + + forward_id = "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Dashboard:product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + to_vertex_id='User:foo@bar.com', + label='READ_BY' + ) + reverse_id = "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='User:foo@bar.com', + to_vertex_id='Dashboard:product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + label='READ' + ) + + dashboard_id = 'Dashboard:product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id' + neptune_forward_expected = { + NEPTUNE_HEADER_ID: forward_id, + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: forward_id, + NEPTUNE_RELATIONSHIP_HEADER_FROM: dashboard_id, + NEPTUNE_RELATIONSHIP_HEADER_TO: 'User:foo@bar.com', + NEPTUNE_HEADER_LABEL: 'READ_BY', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'read_count:Long(single)': 123 + } + + neptune_reversed_expected = { + NEPTUNE_HEADER_ID: reverse_id, + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: reverse_id, + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'User:foo@bar.com', + NEPTUNE_RELATIONSHIP_HEADER_TO: dashboard_id, + NEPTUNE_HEADER_LABEL: 'READ', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'read_count:Long(single)': 123, + } + + assert actual is not None + self.maxDiff = None + self.assertDictEqual(neptune_forward_expected, actual_serialized[0]) + self.assertDictEqual(neptune_reversed_expected, actual_serialized[1]) + self.assertIsNone(dashboard_usage.create_next_relation()) + + def test_dashboard_usage_user_records(self) -> None: + dashboard_usage = DashboardUsage(dashboard_group_id='dashboard_group_id', dashboard_id='dashboard_id', + email='foo@bar.com', view_count=123, cluster='cluster_id', + product='product_id', should_create_user_node=True) + + actual1 = dashboard_usage.create_next_record() + actual1_serialized = mysql_serializer.serialize_record(actual1) + expected1 = { + 'rk': 'foo@bar.com', + 'email': 'foo@bar.com', + } + + assert actual1 is not None + self.assertDictEqual(expected1, actual1_serialized) + + actual2 = dashboard_usage.create_next_record() + actual2_serialized = mysql_serializer.serialize_record(actual2) + expected2 = { + 'user_rk': 'foo@bar.com', + 'dashboard_rk': 'product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', + 'read_count': 123 + } + + assert actual2 is not None + self.assertDictEqual(expected2, actual2_serialized) + 
self.assertIsNone(dashboard_usage.create_next_record()) diff --git a/databuilder/tests/unit/models/feature/test_feature_generation_code.py b/databuilder/tests/unit/models/feature/test_feature_generation_code.py new file mode 100644 index 0000000000..18555ed86a --- /dev/null +++ b/databuilder/tests/unit/models/feature/test_feature_generation_code.py @@ -0,0 +1,62 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from databuilder.models.feature.feature_generation_code import FeatureGenerationCode +from databuilder.serializers import neo4_serializer + + +class TestFeatureGenerationCode(unittest.TestCase): + def setUp(self) -> None: + self.gencode = FeatureGenerationCode( + feature_group='group1', + feature_name='feat_name_123', + feature_version='2.0.0', + text='select * from hive.schema.table', + source='foobar', + last_executed_timestamp=1622596581, + ) + + self.expected_nodes = [ + { + 'KEY': 'group1/feat_name_123/2.0.0/_generation_code', + 'LABEL': 'Feature_Generation_Code', + 'text': 'select * from hive.schema.table', + 'source': 'foobar', + 'last_executed_timestamp:UNQUOTED': 1622596581, + } + ] + + self.expected_rels = [ + { + 'START_LABEL': 'Feature', + 'END_LABEL': 'Feature_Generation_Code', + 'START_KEY': 'group1/feat_name_123/2.0.0', + 'END_KEY': 'group1/feat_name_123/2.0.0/_generation_code', + 'TYPE': 'GENERATION_CODE', + 'REVERSE_TYPE': 'GENERATION_CODE_OF', + } + ] + + def test_basic_example(self) -> None: + node_row = self.gencode.next_node() + actual = [] + while node_row: + node_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_serialized) + node_row = self.gencode.next_node() + + self.assertEqual(self.expected_nodes, actual) + + relation_row = self.gencode.next_relation() + actual = [] + while relation_row: + relation_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_serialized) + relation_row = self.gencode.next_relation() + self.assertEqual(self.expected_rels, actual) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/models/feature/test_feature_metadata.py b/databuilder/tests/unit/models/feature/test_feature_metadata.py new file mode 100644 index 0000000000..b68a584ffb --- /dev/null +++ b/databuilder/tests/unit/models/feature/test_feature_metadata.py @@ -0,0 +1,193 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from databuilder.models.feature.feature_metadata import FeatureMetadata +from databuilder.serializers import neo4_serializer + + +class TestFeatureMetadata(unittest.TestCase): + def setUp(self) -> None: + # reset node cache + FeatureMetadata.processed_feature_group_keys = set() + FeatureMetadata.processed_database_keys = set() + + self.full_metadata = FeatureMetadata( + feature_group='My Feature Group', + name='feature_123', + version='2.0.0', + status='ready', + entity='Buyer', + data_type='float', + availability=['hive', 'dynamo'], + description='My awesome feature', + tags=['qa passed', 'core'], + created_timestamp=1622596581, + ) + + self.required_only_metadata = FeatureMetadata( + feature_group='My Feature Group', + name='feature_123', + version='2.0.0', + ) + + self.expected_nodes_full = [ + { + 'KEY': 'My Feature Group/feature_123/2.0.0', + 'LABEL': 'Feature', + 'name': 'feature_123', + 'version': '2.0.0', + 'status': 'ready', + 'entity': 'Buyer', + 'data_type': 'float', + 'created_timestamp:UNQUOTED': 1622596581, + }, + { + 'KEY': 'My Feature Group', + 'LABEL': 'Feature_Group', + 'name': 'My Feature Group', + }, + { + 'KEY': 'My Feature Group/feature_123/2.0.0/_description', + 'LABEL': 'Description', + 'description_source': 'description', + 'description': 'My awesome feature', + }, + { + 'KEY': 'database://hive', + 'LABEL': 'Database', + 'name': 'hive', + }, + { + 'KEY': 'database://dynamo', + 'LABEL': 'Database', + 'name': 'dynamo', + }, + { + 'KEY': 'qa passed', + 'LABEL': 'Tag', + 'tag_type': 'default', + }, + { + 'KEY': 'core', + 'LABEL': 'Tag', + 'tag_type': 'default', + }, + ] + + self.expected_rels_full = [ + { + 'START_LABEL': 'Feature', + 'END_LABEL': 'Feature_Group', + 'START_KEY': 'My Feature Group/feature_123/2.0.0', + 'END_KEY': 'My Feature Group', + 'TYPE': 'GROUPED_BY', + 'REVERSE_TYPE': 'GROUPS', + }, + { + 'START_LABEL': 'Feature', + 'END_LABEL': 'Description', + 'START_KEY': 'My Feature Group/feature_123/2.0.0', + 'END_KEY': 'My Feature Group/feature_123/2.0.0/_description', + 'TYPE': 'DESCRIPTION', + 'REVERSE_TYPE': 'DESCRIPTION_OF', + }, + { + 'START_LABEL': 'Feature', + 'END_LABEL': 'Database', + 'START_KEY': 'My Feature Group/feature_123/2.0.0', + 'END_KEY': 'database://hive', + 'TYPE': 'FEATURE_AVAILABLE_IN', + 'REVERSE_TYPE': 'AVAILABLE_FEATURE', + }, + { + 'START_LABEL': 'Feature', + 'END_LABEL': 'Database', + 'START_KEY': 'My Feature Group/feature_123/2.0.0', + 'END_KEY': 'database://dynamo', + 'TYPE': 'FEATURE_AVAILABLE_IN', + 'REVERSE_TYPE': 'AVAILABLE_FEATURE', + }, + { + 'START_LABEL': 'Feature', + 'END_LABEL': 'Tag', + 'START_KEY': 'My Feature Group/feature_123/2.0.0', + 'END_KEY': 'qa passed', + 'TYPE': 'TAGGED_BY', + 'REVERSE_TYPE': 'TAG', + }, + { + 'START_LABEL': 'Feature', + 'END_LABEL': 'Tag', + 'START_KEY': 'My Feature Group/feature_123/2.0.0', + 'END_KEY': 'core', + 'TYPE': 'TAGGED_BY', + 'REVERSE_TYPE': 'TAG', + } + ] + + self.expected_nodes_required_only = [ + { + 'KEY': 'My Feature Group/feature_123/2.0.0', + 'LABEL': 'Feature', + 'name': 'feature_123', + 'version': '2.0.0', + }, + { + 'KEY': 'My Feature Group', + 'LABEL': 'Feature_Group', + 'name': 'My Feature Group', + }, + ] + + self.expected_rels_required_only = [ + { + 'START_LABEL': 'Feature', + 'END_LABEL': 'Feature_Group', + 'START_KEY': 'My Feature Group/feature_123/2.0.0', + 'END_KEY': 'My Feature Group', + 'TYPE': 'GROUPED_BY', + 'REVERSE_TYPE': 'GROUPS', + }, + ] + + def test_full_example(self) -> None: + node_row = 
self.full_metadata.next_node() + actual = [] + while node_row: + node_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_serialized) + node_row = self.full_metadata.next_node() + + self.assertEqual(self.expected_nodes_full, actual) + + relation_row = self.full_metadata.next_relation() + actual = [] + while relation_row: + relation_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_serialized) + relation_row = self.full_metadata.next_relation() + self.assertEqual(self.expected_rels_full, actual) + + def test_required_only_example(self) -> None: + node_row = self.required_only_metadata.next_node() + actual = [] + while node_row: + node_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_serialized) + node_row = self.required_only_metadata.next_node() + + self.assertEqual(self.expected_nodes_required_only, actual) + + relation_row = self.required_only_metadata.next_relation() + actual = [] + while relation_row: + relation_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_serialized) + relation_row = self.required_only_metadata.next_relation() + self.assertEqual(self.expected_rels_required_only, actual) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/models/feature/test_feature_watermark.py b/databuilder/tests/unit/models/feature/test_feature_watermark.py new file mode 100644 index 0000000000..d87454ddea --- /dev/null +++ b/databuilder/tests/unit/models/feature/test_feature_watermark.py @@ -0,0 +1,61 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from databuilder.models.feature.feature_watermark import FeatureWatermark +from databuilder.serializers import neo4_serializer + + +class TestFeatureWatermark(unittest.TestCase): + def setUp(self) -> None: + self.watermark = FeatureWatermark( + feature_group='group1', + feature_name='feat_name_123', + feature_version='2.0.0', + timestamp=1622596581, + wm_type='low_watermark' + ) + + self.expected_nodes = [ + { + 'KEY': 'group1/feat_name_123/2.0.0/low_watermark', + 'LABEL': 'Feature_Watermark', + 'timestamp:UNQUOTED': 1622596581, + 'watermark_type': 'low_watermark', + } + ] + + self.expected_rels = [ + { + 'START_LABEL': 'Feature', + 'END_LABEL': 'Feature_Watermark', + 'START_KEY': 'group1/feat_name_123/2.0.0', + 'END_KEY': 'group1/feat_name_123/2.0.0/low_watermark', + 'TYPE': 'WATERMARK', + 'REVERSE_TYPE': 'BELONG_TO_FEATURE', + } + ] + + def test_basic_feature_watermark(self) -> None: + node = self.watermark.next_node() + actual = [] + while node: + node_serialized = neo4_serializer.serialize_node(node) + actual.append(node_serialized) + node = self.watermark.next_node() + + self.assertEqual(self.expected_nodes, actual) + + relation = self.watermark.next_relation() + actual = [] + while relation: + relation_serialized = neo4_serializer.serialize_relationship(relation) + actual.append(relation_serialized) + relation = self.watermark.next_relation() + + self.assertEqual(self.expected_rels, actual) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/models/query/__init__.py b/databuilder/tests/unit/models/query/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/models/query/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/models/query/test_base.py b/databuilder/tests/unit/models/query/test_base.py new file mode 100644 index 0000000000..08b1424eea --- /dev/null +++ b/databuilder/tests/unit/models/query/test_base.py @@ -0,0 +1,30 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from databuilder.models.query.base import QueryBase + + +class TestQueryBase(unittest.TestCase): + def test_normalize_mixed_case(self) -> None: + query = 'SELECT foo from BAR' + expected = 'select foo from bar' + self.assertEqual(QueryBase._normalize(query), expected) + + def test_normalize_mixed_case_string(self) -> None: + query = "SELECT 'foo BaR'" + expected = "select 'foo BaR'" + self.assertEqual(QueryBase._normalize(query), expected) + + def test_normalize_lots_of_space(self) -> None: + query = ''' + SELECT foo AS bar + FROM baz''' + expected = 'select foo as bar from baz' + self.assertEqual(QueryBase._normalize(query), expected) + + def test_trailing_semicolon(self) -> None: + query = "select 'a;b;c';" + expected = "select 'a;b;c'" + self.assertEqual(QueryBase._normalize(query), expected) diff --git a/databuilder/tests/unit/models/query/test_query.py b/databuilder/tests/unit/models/query/test_query.py new file mode 100644 index 0000000000..b31aab99dc --- /dev/null +++ b/databuilder/tests/unit/models/query/test_query.py @@ -0,0 +1,120 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from databuilder.models.graph_serializable import ( + RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, RELATION_START_LABEL, + RELATION_TYPE, +) +from databuilder.models.query import QueryMetadata +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata +from databuilder.models.user import User +from databuilder.serializers import neo4_serializer + + +class TestQuery(unittest.TestCase): + + def setUp(self) -> None: + self.maxDiff = None + super(TestQuery, self).setUp() + self.user = User(first_name='test_first', + last_name='test_last', + full_name='test_first test_last', + email='test@email.com', + github_username='github_test', + team_name='test_team', + employee_type='FTE', + manager_email='test_manager@email.com', + slack_id='slack', + is_active=True, + profile_url='https://profile', + updated_at=1, + role_name='swe') + self.table_metadata = TableMetadata( + 'hive', + 'gold', + 'test_schema1', + 'test_table1', + 'test_table1', + [ + ColumnMetadata('test_id1', 'description of test_table1', 'bigint', 0), + ColumnMetadata('test_id2', 'description of test_id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5) + ] + ) + self.sql = "select * from table" + self.query_metadata = QueryMetadata(sql=self.sql, + tables=[self.table_metadata], + user=self.user) + self._query_hash = 'da44ff72560e593a8eca9ffcee6a2696' + + def test_get_model_key(self) -> None: + key = QueryMetadata.get_key(sql_hash=self.query_metadata.sql_hash) + self.assertEqual(key, self._query_hash) + + def test_create_nodes(self) -> None: + expected_nodes = [{ + 'LABEL': 'Query', + 'KEY': self._query_hash, + 'sql': self.sql + }] + + actual = [] + node = self.query_metadata.create_next_node() + while node: + serialized_node = 
neo4_serializer.serialize_node(node) + actual.append(serialized_node) + node = self.query_metadata.create_next_node() + + self.assertEqual(actual, expected_nodes) + + def test_create_relation(self) -> None: + actual = [] + relation = self.query_metadata.create_next_relation() + while relation: + serialized_relation = neo4_serializer.serialize_relationship(relation) + actual.append(serialized_relation) + relation = self.query_metadata.create_next_relation() + + expected_relations = [ + { + RELATION_START_KEY: 'hive://gold.test_schema1/test_table1', + RELATION_START_LABEL: TableMetadata.TABLE_NODE_LABEL, + RELATION_END_KEY: self._query_hash, + RELATION_END_LABEL: QueryMetadata.NODE_LABEL, + RELATION_TYPE: QueryMetadata.TABLE_QUERY_RELATION_TYPE, + RELATION_REVERSE_TYPE: QueryMetadata.INVERSE_TABLE_QUERY_RELATION_TYPE + }, + { + RELATION_START_KEY: 'test@email.com', + RELATION_START_LABEL: User.USER_NODE_LABEL, + RELATION_END_KEY: self._query_hash, + RELATION_END_LABEL: QueryMetadata.NODE_LABEL, + RELATION_TYPE: QueryMetadata.USER_QUERY_RELATION_TYPE, + RELATION_REVERSE_TYPE: QueryMetadata.INVERSE_USER_QUERY_RELATION_TYPE + } + ] + + self.assertEquals(expected_relations, actual) + + def test_keys_of_query_containing_strings_with_spaces(self) -> None: + query_metadata1 = QueryMetadata(sql="select * from table a where a.field == 'xyz'", + tables=[self.table_metadata]) + + query_metadata2 = QueryMetadata(sql="select * from table a where a.field == 'x y z'", + tables=[self.table_metadata]) + + self.assertNotEqual(query_metadata1.get_key_self(), query_metadata2.get_key_self()) + + def test_keys_of_query_containing_strings_with_mixed_case(self) -> None: + query_metadata1 = QueryMetadata(sql="select * from table a where a.field == 'x Y z'", + tables=[self.table_metadata]) + + query_metadata2 = QueryMetadata(sql="select * from table a where a.field == 'x y z'", + tables=[self.table_metadata]) + + self.assertNotEqual(query_metadata1.get_key_self(), query_metadata2.get_key_self()) diff --git a/databuilder/tests/unit/models/query/test_query_execution.py b/databuilder/tests/unit/models/query/test_query_execution.py new file mode 100644 index 0000000000..2338527500 --- /dev/null +++ b/databuilder/tests/unit/models/query/test_query_execution.py @@ -0,0 +1,83 @@ + +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from databuilder.models.graph_serializable import ( + RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, RELATION_START_LABEL, + RELATION_TYPE, +) +from databuilder.models.query.query import QueryMetadata +from databuilder.models.query.query_execution import QueryExecutionsMetadata +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata +from databuilder.serializers import neo4_serializer + + +class TestQueryExecution(unittest.TestCase): + + def setUp(self) -> None: + super(TestQueryExecution, self).setUp() + # Display full diffs + self.maxDiff = None + self.table_metadata = TableMetadata( + 'hive', + 'gold', + 'test_schema1', + 'test_table1', + 'test_table1', + [ + ColumnMetadata('field', '', '', 0), + ] + ) + self.query_metadata = QueryMetadata(sql="select * from table a where a.field > 3", + tables=[self.table_metadata]) + + self.query_join_metadata = QueryExecutionsMetadata(query_metadata=self.query_metadata, + start_time=10, + execution_count=7) + self._expected_key = '748c28f86de411b1d2b9deb6ae105eba-10' + + def test_get_model_key(self) -> None: + key = QueryExecutionsMetadata.get_key(query_key=self.query_metadata.get_key_self(), start_time=10) + + self.assertEqual(key, self._expected_key) + + def test_create_nodes(self) -> None: + expected_nodes = [{ + 'LABEL': QueryExecutionsMetadata.NODE_LABEL, + 'KEY': self._expected_key, + 'execution_count:UNQUOTED': 7, + 'start_time:UNQUOTED': 10, + 'window_duration': 'daily' + }] + + actual = [] + node = self.query_join_metadata.create_next_node() + while node: + serialized_node = neo4_serializer.serialize_node(node) + actual.append(serialized_node) + node = self.query_join_metadata.create_next_node() + + self.assertEqual(actual, expected_nodes) + + def test_create_relation(self) -> None: + actual = [] + relation = self.query_join_metadata.create_next_relation() + while relation: + serialized_relation = neo4_serializer.serialize_relationship(relation) + actual.append(serialized_relation) + relation = self.query_join_metadata.create_next_relation() + + self.maxDiff = None + expected_relations = [ + { + RELATION_END_KEY: self._expected_key, + RELATION_END_LABEL: QueryExecutionsMetadata.NODE_LABEL, + RELATION_REVERSE_TYPE: QueryExecutionsMetadata.INVERSE_QUERY_EXECUTION_RELATION_TYPE, + RELATION_START_KEY: self.query_metadata.get_key_self(), + RELATION_START_LABEL: QueryMetadata.NODE_LABEL, + RELATION_TYPE: QueryExecutionsMetadata.QUERY_EXECUTION_RELATION_TYPE + } + ] + self.assertEquals(expected_relations, actual) diff --git a/databuilder/tests/unit/models/query/test_query_join.py b/databuilder/tests/unit/models/query/test_query_join.py new file mode 100644 index 0000000000..70f04b7136 --- /dev/null +++ b/databuilder/tests/unit/models/query/test_query_join.py @@ -0,0 +1,130 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from databuilder.models.graph_serializable import ( + RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, RELATION_START_LABEL, + RELATION_TYPE, +) +from databuilder.models.query.query import QueryMetadata +from databuilder.models.query.query_join import QueryJoinMetadata +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata +from databuilder.serializers import neo4_serializer + + +class TestQueryJoin(unittest.TestCase): + + def setUp(self) -> None: + super(TestQueryJoin, self).setUp() + # Display full diffs + self.maxDiff = None + self.tbl1_col = ColumnMetadata('field', '', '', 0) + self.left_table_metadata = TableMetadata( + 'hive', + 'gold', + 'test_schema1', + 'test_table1', + 'test_table1 desc', + [self.tbl1_col] + ) + self.tbl2_col = ColumnMetadata('field', '', '', 0) + self.right_table_metadata = TableMetadata( + 'hive', + 'gold', + 'test_schema1', + 'test_table2', + 'test_table2 desc', + [self.tbl2_col] + ) + self.query_metadata = QueryMetadata(sql="select * from table a where a.field > 3", + tables=[self.left_table_metadata, self.right_table_metadata]) + + self.query_join_metadata = QueryJoinMetadata( + left_table=self.left_table_metadata, + right_table=self.right_table_metadata, + left_column=self.tbl1_col, + right_column=self.tbl2_col, + join_type='inner join', + join_operator='=', + join_sql='test_table1 = join test_table2 on test_tabl1.field = test_table2.field', + query_metadata=self.query_metadata + ) + self._expected_key = ( + 'inner-join-' + 'hive://gold.test_schema1/test_table1/field-' + '=-' + 'hive://gold.test_schema1/test_table2/field' + ) + + def test_get_model_key(self) -> None: + key = QueryJoinMetadata.get_key(left_column_key=self.left_table_metadata._get_col_key(col=self.tbl1_col), + right_column_key=self.right_table_metadata._get_col_key(col=self.tbl2_col), + join_type='inner join', + operator='=') + + self.assertEqual(key, self._expected_key) + + def test_create_nodes(self) -> None: + expected_nodes = [{ + 'LABEL': 'Join', + 'KEY': self._expected_key, + 'join_sql': 'test_table1 = join test_table2 on test_tabl1.field = test_table2.field', + 'join_type': 'inner join', + 'left_cluster': 'gold', + 'left_database': 'hive', + 'left_schema': 'test_schema1', + 'left_table': 'test_table1', + 'left_table_key': 'hive://gold.test_schema1/test_table1', + 'operator': '=', + 'right_cluster': 'gold', + 'right_database': 'hive', + 'right_schema': 'test_schema1', + 'right_table': 'test_table2', + 'right_table_key': 'hive://gold.test_schema1/test_table2' + }] + + actual = [] + node = self.query_join_metadata.create_next_node() + while node: + serialized_node = neo4_serializer.serialize_node(node) + actual.append(serialized_node) + node = self.query_join_metadata.create_next_node() + + self.assertEqual(actual, expected_nodes) + + def test_create_relation(self) -> None: + actual = [] + relation = self.query_join_metadata.create_next_relation() + while relation: + serialized_relation = neo4_serializer.serialize_relationship(relation) + actual.append(serialized_relation) + relation = self.query_join_metadata.create_next_relation() + + expected_relations = [ + { + RELATION_END_KEY: self._expected_key, + RELATION_END_LABEL: QueryJoinMetadata.NODE_LABEL, + RELATION_REVERSE_TYPE: QueryJoinMetadata.INVERSE_COLUMN_JOIN_RELATION_TYPE, + RELATION_START_KEY: 'hive://gold.test_schema1/test_table1/field', + RELATION_START_LABEL: ColumnMetadata.COLUMN_NODE_LABEL, + RELATION_TYPE: 
QueryJoinMetadata.COLUMN_JOIN_RELATION_TYPE + }, + { + RELATION_END_KEY: self._expected_key, + RELATION_END_LABEL: QueryJoinMetadata.NODE_LABEL, + RELATION_REVERSE_TYPE: QueryJoinMetadata.INVERSE_COLUMN_JOIN_RELATION_TYPE, + RELATION_START_KEY: 'hive://gold.test_schema1/test_table2/field', + RELATION_START_LABEL: ColumnMetadata.COLUMN_NODE_LABEL, + RELATION_TYPE: QueryJoinMetadata.COLUMN_JOIN_RELATION_TYPE + }, + { + RELATION_END_KEY: self._expected_key, + RELATION_END_LABEL: QueryJoinMetadata.NODE_LABEL, + RELATION_REVERSE_TYPE: QueryJoinMetadata.INVERSE_QUERY_JOIN_RELATION_TYPE, + RELATION_START_KEY: '748c28f86de411b1d2b9deb6ae105eba', + RELATION_START_LABEL: QueryMetadata.NODE_LABEL, + RELATION_TYPE: QueryJoinMetadata.QUERY_JOIN_RELATION_TYPE + } + ] + self.assertEquals(expected_relations, actual) diff --git a/databuilder/tests/unit/models/query/test_query_where.py b/databuilder/tests/unit/models/query/test_query_where.py new file mode 100644 index 0000000000..9413ef8fc8 --- /dev/null +++ b/databuilder/tests/unit/models/query/test_query_where.py @@ -0,0 +1,93 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from databuilder.models.graph_serializable import ( + RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, RELATION_START_LABEL, + RELATION_TYPE, +) +from databuilder.models.query.query import QueryMetadata +from databuilder.models.query.query_where import QueryWhereMetadata +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata +from databuilder.serializers import neo4_serializer + + +class TestQueryWhere(unittest.TestCase): + + def setUp(self) -> None: + super(TestQueryWhere, self).setUp() + # Display full diffs + self.maxDiff = None + self.table_metadata = TableMetadata( + 'hive', + 'gold', + 'test_schema1', + 'test_table1', + 'test_table1', + [ + ColumnMetadata('field', '', '', 0), + ] + ) + self.query_metadata = QueryMetadata(sql="select * from table a where a.field > 3", + tables=[self.table_metadata]) + + self.query_where_metadata = QueryWhereMetadata(tables=[self.table_metadata], + where_clause='a.field > 3', + left_arg='field', + right_arg='3', + operator='>', + query_metadata=self.query_metadata) + self._expected_key_hash = '795a2a16184c09b88ae518cd5230cfb5-be8634550905b354508dc8aba8008c14' + + def test_get_model_key(self) -> None: + key = QueryWhereMetadata.get_key(table_hash=self.query_where_metadata._table_hash, + where_hash=self.query_where_metadata._where_hash) + self.assertEqual(key, self._expected_key_hash) + + def test_create_nodes(self) -> None: + expected_nodes = [{ + 'LABEL': 'Where', + 'KEY': self._expected_key_hash, + 'left_arg': 'field', + 'operator': '>', + 'right_arg': '3', + 'where_clause': 'a.field > 3' + }] + + actual = [] + node = self.query_where_metadata.create_next_node() + while node: + serialized_node = neo4_serializer.serialize_node(node) + actual.append(serialized_node) + node = self.query_where_metadata.create_next_node() + + self.assertEqual(actual, expected_nodes) + + def test_create_relation(self) -> None: + actual = [] + relation = self.query_where_metadata.create_next_relation() + while relation: + serialized_relation = neo4_serializer.serialize_relationship(relation) + actual.append(serialized_relation) + relation = self.query_where_metadata.create_next_relation() + + expected_relations = [ + { + RELATION_START_KEY: 'hive://gold.test_schema1/test_table1/field', + RELATION_START_LABEL: ColumnMetadata.COLUMN_NODE_LABEL, + 
RELATION_END_KEY: self._expected_key_hash, + RELATION_END_LABEL: QueryWhereMetadata.NODE_LABEL, + RELATION_TYPE: QueryWhereMetadata.COLUMN_WHERE_RELATION_TYPE, + RELATION_REVERSE_TYPE: QueryWhereMetadata.INVERSE_COLUMN_WHERE_RELATION_TYPE + }, + { + RELATION_START_KEY: self.query_metadata.get_key_self(), + RELATION_START_LABEL: QueryMetadata.NODE_LABEL, + RELATION_END_KEY: self.query_where_metadata.get_key_self(), + RELATION_END_LABEL: QueryWhereMetadata.NODE_LABEL, + RELATION_TYPE: QueryWhereMetadata.QUERY_WHERE_RELATION_TYPE, + RELATION_REVERSE_TYPE: QueryWhereMetadata.INVERSE_QUERY_WHERE_RELATION_TYPE + } + ] + self.assertEquals(expected_relations, actual) diff --git a/databuilder/tests/unit/models/schema/__init__.py b/databuilder/tests/unit/models/schema/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/models/schema/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/models/schema/test_schema.py b/databuilder/tests/unit/models/schema/test_schema.py new file mode 100644 index 0000000000..41d8ddb288 --- /dev/null +++ b/databuilder/tests/unit/models/schema/test_schema.py @@ -0,0 +1,222 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from unittest.mock import ANY + +from databuilder.models.schema.schema import SchemaModel +from databuilder.serializers import ( + mysql_serializer, neo4_serializer, neptune_serializer, +) +from databuilder.serializers.neptune_serializer import ( + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_HEADER_ID, NEPTUNE_HEADER_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_RELATIONSHIP_HEADER_FROM, + NEPTUNE_RELATIONSHIP_HEADER_TO, +) + + +class TestSchemaDescription(unittest.TestCase): + def setUp(self) -> None: + self.schema = SchemaModel( + schema_key='db://cluster.schema', + schema='schema_name', + description='foo' + ) + + def test_create_nodes(self) -> None: + schema_node = self.schema.create_next_node() + serialized_schema_node = neo4_serializer.serialize_node(schema_node) + schema_desc_node = self.schema.create_next_node() + serialized_schema_desc_node = neo4_serializer.serialize_node(schema_desc_node) + self.assertDictEqual( + serialized_schema_node, + {'name': 'schema_name', 'KEY': 'db://cluster.schema', 'LABEL': 'Schema'} + ) + self.assertDictEqual(serialized_schema_desc_node, + {'description_source': 'description', 'description': 'foo', + 'KEY': 'db://cluster.schema/_description', 'LABEL': 'Description'} + ) + self.assertIsNone(self.schema.create_next_node()) + + def test_create_nodes_neptune(self) -> None: + schema_node = self.schema.create_next_node() + expected_serialized_schema_node = { + NEPTUNE_HEADER_ID: 'Schema:db://cluster.schema', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'db://cluster.schema', + NEPTUNE_HEADER_LABEL: 'Schema', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'name:String(single)': 'schema_name', + } + serialized_schema_node = neptune_serializer.convert_node(schema_node) + self.assertDictEqual( + expected_serialized_schema_node, + serialized_schema_node + ) + 
schema_desc_node = self.schema.create_next_node() + excepted_serialized_schema_desc_node = { + NEPTUNE_HEADER_ID: 'Description:db://cluster.schema/_description', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'db://cluster.schema/_description', + NEPTUNE_HEADER_LABEL: 'Description', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'description_source:String(single)': 'description', + 'description:String(single)': 'foo', + } + serialized_schema_desc_node = neptune_serializer.convert_node(schema_desc_node) + self.assertDictEqual( + excepted_serialized_schema_desc_node, + serialized_schema_desc_node + ) + + def test_create_nodes_no_description(self) -> None: + + schema = SchemaModel(schema_key='db://cluster.schema', + schema='schema_name') + + schema_node = schema.create_next_node() + serialized_schema_node = neo4_serializer.serialize_node(schema_node) + + self.assertDictEqual(serialized_schema_node, + {'name': 'schema_name', 'KEY': 'db://cluster.schema', 'LABEL': 'Schema'}) + self.assertIsNone(schema.create_next_node()) + + def test_create_nodes_programmatic_description(self) -> None: + + schema = SchemaModel(schema_key='db://cluster.schema', + schema='schema_name', + description='foo', + description_source='bar') + + schema_node = schema.create_next_node() + serialized_schema_node = neo4_serializer.serialize_node(schema_node) + schema_desc_node = schema.create_next_node() + serialized_schema_prod_desc_node = neo4_serializer.serialize_node(schema_desc_node) + + self.assertDictEqual(serialized_schema_node, + {'name': 'schema_name', 'KEY': 'db://cluster.schema', 'LABEL': 'Schema'}) + self.assertDictEqual(serialized_schema_prod_desc_node, + {'description_source': 'bar', 'description': 'foo', + 'KEY': 'db://cluster.schema/_bar_description', 'LABEL': 'Programmatic_Description'}) + self.assertIsNone(schema.create_next_node()) + + def test_create_relation(self) -> None: + actual = self.schema.create_next_relation() + serialized_actual = neo4_serializer.serialize_relationship(actual) + expected = {'END_KEY': 'db://cluster.schema/_description', 'START_LABEL': 'Schema', 'END_LABEL': 'Description', + 'START_KEY': 'db://cluster.schema', 'TYPE': 'DESCRIPTION', 'REVERSE_TYPE': 'DESCRIPTION_OF'} + + self.assertEqual(expected, serialized_actual) + self.assertIsNone(self.schema.create_next_relation()) + + def test_create_relation_neptune(self) -> None: + actual = self.schema.create_next_relation() + serialized_actual = neptune_serializer.convert_relationship(actual) + forward_header_id = "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Schema:db://cluster.schema', + to_vertex_id='Description:db://cluster.schema/_description', + label='DESCRIPTION' + ) + reverse_header_id = "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Description:db://cluster.schema/_description', + to_vertex_id='Schema:db://cluster.schema', + label='DESCRIPTION_OF' + ) + + neptune_forward_expected = { + NEPTUNE_HEADER_ID: forward_header_id, + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: forward_header_id, + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Schema:db://cluster.schema', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Description:db://cluster.schema/_description', + NEPTUNE_HEADER_LABEL: 'DESCRIPTION', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + + 
neptune_reversed_expected = { + NEPTUNE_HEADER_ID: reverse_header_id, + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: reverse_header_id, + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Description:db://cluster.schema/_description', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Schema:db://cluster.schema', + NEPTUNE_HEADER_LABEL: 'DESCRIPTION_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + + self.maxDiff = None + self.assertDictEqual(serialized_actual[0], neptune_forward_expected) + self.assertDictEqual(serialized_actual[1], neptune_reversed_expected) + + def test_create_relation_no_description(self) -> None: + schema = SchemaModel(schema_key='db://cluster.schema', + schema='schema_name') + + self.assertIsNone(schema.create_next_relation()) + + def test_create_relation_programmatic_description(self) -> None: + schema = SchemaModel(schema_key='db://cluster.schema', + schema='schema_name', + description='foo', + description_source='bar') + + actual = schema.create_next_relation() + serialized_actual = neo4_serializer.serialize_relationship(actual) + expected = { + 'END_KEY': 'db://cluster.schema/_bar_description', 'START_LABEL': 'Schema', + 'END_LABEL': 'Programmatic_Description', 'START_KEY': 'db://cluster.schema', 'TYPE': 'DESCRIPTION', + 'REVERSE_TYPE': 'DESCRIPTION_OF' + } + + self.assertEqual(expected, serialized_actual) + self.assertIsNone(schema.create_next_relation()) + + def test_create_records(self) -> None: + schema_record = self.schema.create_next_record() + serialized_schema_record = mysql_serializer.serialize_record(schema_record) + schema_desc_record = self.schema.create_next_record() + serialized_schema_desc_record = mysql_serializer.serialize_record(schema_desc_record) + self.assertDictEqual(serialized_schema_record, {'rk': 'db://cluster.schema', 'name': 'schema_name', + 'cluster_rk': 'db://cluster'}) + self.assertDictEqual(serialized_schema_desc_record, {'rk': 'db://cluster.schema/_description', + 'description_source': 'description', 'description': 'foo', + 'schema_rk': 'db://cluster.schema'}) + + def test_create_records_no_description(self) -> None: + schema = SchemaModel(schema_key='db://cluster.schema', + schema='schema_name') + schema_record = schema.create_next_record() + serialized_schema_record = mysql_serializer.serialize_record(schema_record) + self.assertDictEqual(serialized_schema_record, {'rk': 'db://cluster.schema', 'name': 'schema_name', + 'cluster_rk': 'db://cluster'}) + self.assertIsNone(schema.create_next_record()) + + def test_create_records_programmatic_description(self) -> None: + schema = SchemaModel(schema_key='db://cluster.schema', + schema='schema_name', + description='foo', + description_source='bar') + + schema_record = schema.create_next_record() + serialized_schema_record = mysql_serializer.serialize_record(schema_record) + schema_prog_desc_record = schema.create_next_record() + serialized_schema_prog_desc_record = mysql_serializer.serialize_record(schema_prog_desc_record) + self.assertDictEqual(serialized_schema_record, {'rk': 'db://cluster.schema', 'name': 'schema_name', + 'cluster_rk': 'db://cluster'}) + self.assertDictEqual(serialized_schema_prog_desc_record, {'rk': 'db://cluster.schema/_bar_description', + 'description_source': 'bar', + 'description': 'foo', + 'schema_rk': 'db://cluster.schema'}) + + def test_get_cluster_key(self) -> None: + schema_key = 'a123b_staging://cluster.schema' + schema_name = 'schema_name' + schema 
= SchemaModel(schema_key=schema_key, schema=schema_name) + assert schema._get_cluster_key(schema_key) == 'a123b_staging://cluster' + + failed_schema_key_1 = 'a123b.staging://cluster.schema' + self.assertRaises(Exception, SchemaModel, schema_key=failed_schema_key_1, schema_name=schema_name) + + failed_schema_key_2 = 'a123b-staging://cluster.schema' + self.assertRaises(Exception, SchemaModel, schema_key=failed_schema_key_2, schema_name=schema_name) diff --git a/databuilder/tests/unit/models/test_application.py b/databuilder/tests/unit/models/test_application.py new file mode 100644 index 0000000000..827ad93643 --- /dev/null +++ b/databuilder/tests/unit/models/test_application.py @@ -0,0 +1,286 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from collections import namedtuple +from dataclasses import dataclass +from typing import Dict, List +from unittest.mock import ANY + +from databuilder.models.application import Application, GenericApplication +from databuilder.models.graph_serializable import ( + NODE_KEY, NODE_LABEL, RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, + RELATION_START_LABEL, RELATION_TYPE, +) +from databuilder.models.table_metadata import TableMetadata +from databuilder.serializers import ( + mysql_serializer, neo4_serializer, neptune_serializer, +) +from databuilder.serializers.neptune_serializer import ( + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_HEADER_ID, NEPTUNE_HEADER_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_RELATIONSHIP_HEADER_FROM, + NEPTUNE_RELATIONSHIP_HEADER_TO, +) + + +@dataclass +class ApplicationTestCase: + application: GenericApplication + expected_node_results: List[Dict] + expected_relation_results: List[Dict] + expected_records: List[Dict] + + +class TestApplication(unittest.TestCase): + + def setUp(self) -> None: + super(TestApplication, self).setUp() + + self.test_cases = [] + + # Explicitly add test case for Airflow to verify backwards compatibility + airflow_application = Application( + task_id='hive.default.test_table', + dag_id='event_test', + schema='default', + table_name='test_table', + application_url_template='airflow_host.net/admin/airflow/tree?dag_id={dag_id}', + ) + + airflow_expected_node_results = [{ + NODE_KEY: 'application://gold.airflow/event_test/hive.default.test_table', + NODE_LABEL: 'Application', + 'application_url': 'airflow_host.net/admin/airflow/tree?dag_id=event_test', + 'id': 'event_test/hive.default.test_table', + 'name': 'Airflow', + 'description': 'Airflow with id event_test/hive.default.test_table' + }] + + airflow_expected_relation_results = [{ + RELATION_START_KEY: 'hive://gold.default/test_table', + RELATION_START_LABEL: TableMetadata.TABLE_NODE_LABEL, + RELATION_END_KEY: 'application://gold.airflow/event_test/hive.default.test_table', + RELATION_END_LABEL: 'Application', + RELATION_TYPE: 'DERIVED_FROM', + RELATION_REVERSE_TYPE: 'GENERATES' + }] + + airflow_expected_application_record = { + 'rk': 'application://gold.airflow/event_test/hive.default.test_table', + 'application_url': 'airflow_host.net/admin/airflow/tree?dag_id=event_test', + 'id': 'event_test/hive.default.test_table', + 'name': 'Airflow', + 'description': 'Airflow with id event_test/hive.default.test_table' + } + + 
airflow_expected_application_table_record = { + 'rk': 'hive://gold.default/test_table', + 'application_rk': 'application://gold.airflow/event_test/hive.default.test_table' + } + + airflow_expected_records = [ + airflow_expected_application_record, + airflow_expected_application_table_record, + ] + + self.test_cases.append( + ApplicationTestCase( + airflow_application, + airflow_expected_node_results, + airflow_expected_relation_results, + airflow_expected_records, + ), + ) + + # Test several non-airflow applications + AppTestCase = namedtuple('AppTestCase', ['name', 'generates_table']) + non_airflow_cases = [ + AppTestCase(name='Databricks', generates_table=False), + AppTestCase(name='Snowflake', generates_table=True), + AppTestCase(name='EMR', generates_table=False), + ] + + for case in non_airflow_cases: + application_type = case.name + url = f'https://{application_type.lower()}.com/job/1234' + id = f'{application_type}.hive.test_table' + description = f'{application_type} application for hive.test_table' + table_key = TableMetadata.TABLE_KEY_FORMAT.format( + db='hive', + cluster='gold', + schema='default', + tbl='test_table', + ) + + application = GenericApplication( + start_label=TableMetadata.TABLE_NODE_LABEL, + start_key=table_key, + application_type=application_type, + application_id=id, + application_url=url, + application_description=description, + app_key_override=f'application://{application_type}/hive/test_table', + generates_resource=case.generates_table, + ) + + expected_node_results = [{ + NODE_KEY: f'application://{application_type}/hive/test_table', + NODE_LABEL: 'Application', + 'application_url': url, + 'id': id, + 'name': application_type, + 'description': description, + }] + + expected_relation_results = [{ + RELATION_START_KEY: 'hive://gold.default/test_table', + RELATION_START_LABEL: TableMetadata.TABLE_NODE_LABEL, + RELATION_END_KEY: f'application://{application_type}/hive/test_table', + RELATION_END_LABEL: 'Application', + RELATION_TYPE: (GenericApplication.DERIVED_FROM_REL_TYPE if case.generates_table + else GenericApplication.CONSUMED_BY_REL_TYPE), + RELATION_REVERSE_TYPE: (GenericApplication.GENERATES_REL_TYPE if case.generates_table + else GenericApplication.CONSUMES_REL_TYPE), + }] + + expected_application_record = { + 'rk': f'application://{application_type}/hive/test_table', + 'application_url': url, + 'id': id, + 'name': application_type, + 'description': description, + } + + expected_application_table_record = { + 'rk': 'hive://gold.default/test_table', + 'application_rk': f'application://{application_type}/hive/test_table' + } + + expected_records = [ + expected_application_record, + expected_application_table_record + ] + + self.test_cases.append( + ApplicationTestCase( + application, + expected_node_results, + expected_relation_results, + expected_records, + ), + ) + + def test_get_application_model_key(self) -> None: + for tc in self.test_cases: + self.assertEqual(tc.application.application_key, tc.expected_node_results[0][NODE_KEY]) + + def test_create_nodes(self) -> None: + for tc in self.test_cases: + actual = [] + node = tc.application.create_next_node() + while node: + serialized_next_node = neo4_serializer.serialize_node(node) + actual.append(serialized_next_node) + node = tc.application.create_next_node() + + self.assertEqual(actual, tc.expected_node_results) + + def test_create_nodes_neptune(self) -> None: + for tc in self.test_cases: + actual = [] + next_node = tc.application.create_next_node() + while next_node: + serialized_next_node = 
neptune_serializer.convert_node(next_node) + actual.append(serialized_next_node) + next_node = tc.application.create_next_node() + + node_id = f'Application:{tc.application.application_key}' + node_key = tc.application.application_key + neptune_expected = [{ + NEPTUNE_HEADER_ID: node_id, + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: node_key, + NEPTUNE_HEADER_LABEL: 'Application', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'application_url:String(single)': tc.application.application_url, + 'id:String(single)': tc.application.application_id, + 'name:String(single)': tc.application.application_type, + 'description:String(single)': tc.application.application_description, + }] + self.assertEqual(neptune_expected, actual) + + def test_create_relation(self) -> None: + for tc in self.test_cases: + actual = [] + relation = tc.application.create_next_relation() + while relation: + serialized_relation = neo4_serializer.serialize_relationship(relation) + actual.append(serialized_relation) + relation = tc.application.create_next_relation() + + self.assertEqual(actual, tc.expected_relation_results) + + def test_create_relations_neptune(self) -> None: + for tc in self.test_cases: + application_id = f'Application:{tc.application.application_key}' + table_id = 'Table:hive://gold.default/test_table' + neptune_forward_expected = { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id=table_id, + to_vertex_id=application_id, + label=tc.expected_relation_results[0][RELATION_TYPE], + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id=table_id, + to_vertex_id=application_id, + label=tc.expected_relation_results[0][RELATION_TYPE], + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: table_id, + NEPTUNE_RELATIONSHIP_HEADER_TO: application_id, + NEPTUNE_HEADER_LABEL: tc.expected_relation_results[0][RELATION_TYPE], + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + + neptune_reversed_expected = { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id=application_id, + to_vertex_id=table_id, + label=tc.expected_relation_results[0][RELATION_REVERSE_TYPE], + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id=application_id, + to_vertex_id=table_id, + label=tc.expected_relation_results[0][RELATION_REVERSE_TYPE], + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: application_id, + NEPTUNE_RELATIONSHIP_HEADER_TO: table_id, + NEPTUNE_HEADER_LABEL: tc.expected_relation_results[0][RELATION_REVERSE_TYPE], + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + neptune_expected = [[neptune_forward_expected, neptune_reversed_expected]] + + actual = [] + next_relation = tc.application.create_next_relation() + while next_relation: + serialized_next_relation = neptune_serializer.convert_relationship(next_relation) + actual.append(serialized_next_relation) + next_relation = tc.application.create_next_relation() + + self.assertEqual(actual, neptune_expected) + + def test_create_records(self) -> None: + for tc in self.test_cases: + expected = tc.expected_records + + 
actual = [] + record = tc.application.create_next_record() + while record: + serialized_record = mysql_serializer.serialize_record(record) + actual.append(serialized_record) + record = tc.application.create_next_record() + + self.assertEqual(expected, actual) diff --git a/databuilder/tests/unit/models/test_atlas_serializable.py b/databuilder/tests/unit/models/test_atlas_serializable.py new file mode 100644 index 0000000000..2b29000808 --- /dev/null +++ b/databuilder/tests/unit/models/test_atlas_serializable.py @@ -0,0 +1,227 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 +import unittest +from typing import ( + Iterable, Iterator, Union, +) + +from amundsen_common.utils.atlas import AtlasCommonParams + +from databuilder.models.atlas_entity import AtlasEntity +from databuilder.models.atlas_relationship import AtlasRelationship +from databuilder.models.atlas_serializable import AtlasSerializable +from databuilder.serializers import atlas_serializer +from databuilder.utils.atlas import AtlasSerializedEntityFields, AtlasSerializedEntityOperation + + +class TestSerialize(unittest.TestCase): + + def test_serialize(self) -> None: + actors = [Actor('Tom Cruise'), Actor('Meg Ryan')] + cities = [City('San Diego'), City('Oakland')] + movie = Movie('Top Gun', actors, cities) + + actual = [] + entity = movie.next_atlas_entity() + while entity: + actual.append(atlas_serializer.serialize_entity(entity)) + entity = movie.next_atlas_entity() + + expected = [ + { + 'name': 'Tom Cruise', + 'operation': 'CREATE', + 'qualifiedName': 'actor://Tom Cruise', + 'relationships': None, + 'typeName': 'Actor', + }, + { + 'name': 'Meg Ryan', + 'operation': 'CREATE', + 'qualifiedName': 'actor://Meg Ryan', + 'relationships': None, + 'typeName': 'Actor', + }, + { + 'name': 'San Diego', + 'operation': 'CREATE', + 'qualifiedName': 'city://San Diego', + 'relationships': None, + 'typeName': 'City', + }, + { + 'name': 'Oakland', + 'operation': 'CREATE', + 'qualifiedName': 'city://Oakland', + 'relationships': None, + 'typeName': 'City', + }, + { + 'name': 'Top Gun', + 'operation': 'CREATE', + 'qualifiedName': 'movie://Top Gun', + 'relationships': 'actors#ACTOR#actor://Tom Cruise|actors#ACTOR#actor://Meg Ryan', + 'typeName': 'Movie', + }, + ] + + self.assertEqual(expected, actual) + + actual = [] + relation = movie.next_atlas_relation() + while relation: + actual.append(atlas_serializer.serialize_relationship(relation)) + relation = movie.next_atlas_relation() + + expected = [ + { + 'entityQualifiedName1': 'movie://Top Gun', + 'entityQualifiedName2': 'city://San Diego', + 'entityType1': 'Movie', + 'entityType2': 'City', + 'relationshipType': 'FILMED_AT', + }, + { + 'entityQualifiedName1': 'movie://Top Gun', + 'entityQualifiedName2': 'city://Oakland', + 'entityType1': 'Movie', + 'entityType2': 'City', + 'relationshipType': 'FILMED_AT', + }, + ] + self.assertEqual(expected, actual) + + +class Actor: + TYPE = 'Actor' + KEY_FORMAT = 'actor://{}' + + def __init__(self, name: str) -> None: + self.name = name + + +class City: + TYPE = 'City' + KEY_FORMAT = 'city://{}' + + def __init__(self, name: str) -> None: + self.name = name + + +class Movie(AtlasSerializable): + TYPE = 'Movie' + KEY_FORMAT = 'movie://{}' + MOVIE_ACTOR_RELATION_TYPE = 'ACTOR' + MOVIE_CITY_RELATION_TYPE = 'FILMED_AT' + + def __init__( + self, + name: str, + actors: Iterable[Actor], + cities: Iterable[City], + ) -> None: + self._name = name + self._actors = actors + self._cities = cities + self._entity_iter = 
iter(self._create_next_atlas_entity()) + self._relation_iter = iter(self._create_next_atlas_relation()) + + def create_next_atlas_entity(self) -> Union[AtlasEntity, None]: + try: + return next(self._entity_iter) + except StopIteration: + return None + + def create_next_atlas_relation(self) -> Union[AtlasRelationship, None]: + try: + return next(self._relation_iter) + except StopIteration: + return None + + def _create_next_atlas_entity(self) -> Iterable[AtlasEntity]: + + for actor in self._actors: + attrs_mapping = [ + (AtlasCommonParams.qualified_name, actor.KEY_FORMAT.format(actor.name)), + ('name', actor.name), + ] + + actor_entity_attrs = {} + for attr in attrs_mapping: + attr_key, attr_value = attr + actor_entity_attrs[attr_key] = attr_value + + actor_entity = AtlasEntity( + typeName=actor.TYPE, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=actor_entity_attrs, + relationships=None, + ) + yield actor_entity + + for city in self._cities: + attrs_mapping = [ + (AtlasCommonParams.qualified_name, city.KEY_FORMAT.format(city.name)), + ('name', city.name), + ] + + city_entity_attrs = {} + for attr in attrs_mapping: + attr_key, attr_value = attr + city_entity_attrs[attr_key] = attr_value + + city_entity = AtlasEntity( + typeName=city.TYPE, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=city_entity_attrs, + relationships=None, + ) + yield city_entity + + attrs_mapping = [ + (AtlasCommonParams.qualified_name, self.KEY_FORMAT.format(self._name)), + ('name', self._name), + ] + + movie_entity_attrs = {} + for attr in attrs_mapping: + attr_key, attr_value = attr + movie_entity_attrs[attr_key] = attr_value + + relationship_list = [] + """ + relationship in form 'relation_attribute#relation_entity_type#qualified_name_of_related_object + """ + for actor in self._actors: + relationship_list.append( + AtlasSerializedEntityFields.relationships_kv_separator + .join(( + 'actors', + self.MOVIE_ACTOR_RELATION_TYPE, + actor.KEY_FORMAT.format(actor.name), + )), + ) + + movie_entity = AtlasEntity( + typeName=self.TYPE, + operation=AtlasSerializedEntityOperation.CREATE, + attributes=movie_entity_attrs, + relationships=AtlasSerializedEntityFields.relationships_separator.join(relationship_list), + ) + yield movie_entity + + def _create_next_atlas_relation(self) -> Iterator[AtlasRelationship]: + for city in self._cities: + city_relationship = AtlasRelationship( + relationshipType=self.MOVIE_CITY_RELATION_TYPE, + entityType1=self.TYPE, + entityQualifiedName1=self.KEY_FORMAT.format(self._name), + entityType2=city.TYPE, + entityQualifiedName2=city.KEY_FORMAT.format(city.name), + attributes={}, + ) + yield city_relationship + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/models/test_badge.py b/databuilder/tests/unit/models/test_badge.py new file mode 100644 index 0000000000..4e763525e1 --- /dev/null +++ b/databuilder/tests/unit/models/test_badge.py @@ -0,0 +1,252 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest +from unittest.mock import ANY + +from databuilder.models.badge import Badge, BadgeMetadata +from databuilder.models.graph_serializable import ( + NODE_KEY, NODE_LABEL, RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, + RELATION_START_LABEL, RELATION_TYPE, +) +from databuilder.serializers import ( + mysql_serializer, neo4_serializer, neptune_serializer, +) +from databuilder.serializers.neptune_serializer import ( + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_HEADER_ID, NEPTUNE_HEADER_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_RELATIONSHIP_HEADER_FROM, + NEPTUNE_RELATIONSHIP_HEADER_TO, +) + +db = 'hive' +SCHEMA = 'BASE' +TABLE = 'TEST' +CLUSTER = 'DEFAULT' +badge1 = Badge('badge1', 'column') +badge2 = Badge('badge2', 'column') + + +class TestBadge(unittest.TestCase): + def setUp(self) -> None: + super(TestBadge, self).setUp() + self.badge_metada = BadgeMetadata(start_label='Column', + start_key='hive://default.base/test/ds', + badges=[badge1, badge2]) + + def test_badge_name_category_are_lower_cases(self) -> None: + uppercase_badge = Badge('BadGe3', 'COLUMN_3') + self.assertEqual(uppercase_badge.name, 'badge3') + self.assertEqual(uppercase_badge.category, 'column_3') + + def test_get_badge_key(self) -> None: + badge_key = self.badge_metada.get_badge_key(badge1.name) + self.assertEqual(badge_key, badge1.name) + + def test_create_nodes(self) -> None: + node1 = { + NODE_KEY: BadgeMetadata.BADGE_KEY_FORMAT.format(badge=badge1.name), + NODE_LABEL: BadgeMetadata.BADGE_NODE_LABEL, + BadgeMetadata.BADGE_CATEGORY: badge1.category + } + node2 = { + NODE_KEY: BadgeMetadata.BADGE_KEY_FORMAT.format(badge=badge2.name), + NODE_LABEL: BadgeMetadata.BADGE_NODE_LABEL, + BadgeMetadata.BADGE_CATEGORY: badge2.category + } + expected = [node1, node2] + + actual = [] + node = self.badge_metada.create_next_node() + while node: + serialized_node = neo4_serializer.serialize_node(node) + actual.append(serialized_node) + node = self.badge_metada.create_next_node() + + self.assertEqual(expected, actual) + + def test_create_nodes_neptune(self) -> None: + actual = [] + node = self.badge_metada.create_next_node() + while node: + serialized_node = neptune_serializer.convert_node(node) + actual.append(serialized_node) + node = self.badge_metada.create_next_node() + node_key_1 = BadgeMetadata.BADGE_KEY_FORMAT.format(badge=badge1.name) + node_id_1 = BadgeMetadata.BADGE_NODE_LABEL + ":" + node_key_1 + expected_node1 = { + NEPTUNE_HEADER_ID: node_id_1, + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: node_key_1, + NEPTUNE_HEADER_LABEL: BadgeMetadata.BADGE_NODE_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + BadgeMetadata.BADGE_CATEGORY + ':String(single)': badge1.category + } + node_key_2 = BadgeMetadata.BADGE_KEY_FORMAT.format(badge=badge2.name) + node_id_2 = BadgeMetadata.BADGE_NODE_LABEL + ":" + node_key_2 + expected_node2 = { + NEPTUNE_HEADER_ID: node_id_2, + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: node_key_2, + NEPTUNE_HEADER_LABEL: BadgeMetadata.BADGE_NODE_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + 
NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + BadgeMetadata.BADGE_CATEGORY + ':String(single)': badge2.category + } + expected = [expected_node1, expected_node2] + + self.assertEqual(expected, actual) + + def test_bad_entity_label(self) -> None: + user_label = 'User' + table_key = 'hive://default.base/test' + self.assertRaises(Exception, + BadgeMetadata, + start_label=user_label, + start_key=table_key, + badges=[badge1, badge2]) + + def test_create_relation(self) -> None: + actual = [] + relation = self.badge_metada.create_next_relation() + while relation: + serialized_relation = neo4_serializer.serialize_relationship(relation) + actual.append(serialized_relation) + relation = self.badge_metada.create_next_relation() + + relation1 = { + RELATION_START_LABEL: self.badge_metada.start_label, + RELATION_END_LABEL: BadgeMetadata.BADGE_NODE_LABEL, + RELATION_START_KEY: self.badge_metada.start_key, + RELATION_END_KEY: BadgeMetadata.get_badge_key(badge1.name), + RELATION_TYPE: BadgeMetadata.BADGE_RELATION_TYPE, + RELATION_REVERSE_TYPE: BadgeMetadata.INVERSE_BADGE_RELATION_TYPE, + } + relation2 = { + RELATION_START_LABEL: self.badge_metada.start_label, + RELATION_END_LABEL: BadgeMetadata.BADGE_NODE_LABEL, + RELATION_START_KEY: self.badge_metada.start_key, + RELATION_END_KEY: BadgeMetadata.get_badge_key(badge2.name), + RELATION_TYPE: BadgeMetadata.BADGE_RELATION_TYPE, + RELATION_REVERSE_TYPE: BadgeMetadata.INVERSE_BADGE_RELATION_TYPE, + } + expected = [relation1, relation2] + + self.assertEqual(expected, actual) + + def test_create_relation_neptune(self) -> None: + actual = [] + relation = self.badge_metada.create_next_relation() + while relation: + serialized_relations = neptune_serializer.convert_relationship(relation) + actual.append(serialized_relations) + relation = self.badge_metada.create_next_relation() + + badge_id_1 = BadgeMetadata.BADGE_NODE_LABEL + ':' + BadgeMetadata.get_badge_key(badge1.name) + badge_id_2 = BadgeMetadata.BADGE_NODE_LABEL + ':' + BadgeMetadata.get_badge_key(badge2.name) + start_key = self.badge_metada.start_label + ':' + self.badge_metada.start_key + + neptune_forward_expected_1 = { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id=start_key, + to_vertex_id=badge_id_1, + label=BadgeMetadata.BADGE_RELATION_TYPE, + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id=start_key, + to_vertex_id=badge_id_1, + label=BadgeMetadata.BADGE_RELATION_TYPE, + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: start_key, + NEPTUNE_RELATIONSHIP_HEADER_TO: badge_id_1, + NEPTUNE_HEADER_LABEL: BadgeMetadata.BADGE_RELATION_TYPE, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + + neptune_reversed_expected_1 = { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id=badge_id_1, + to_vertex_id=start_key, + label=BadgeMetadata.INVERSE_BADGE_RELATION_TYPE + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id=badge_id_1, + to_vertex_id=start_key, + label=BadgeMetadata.INVERSE_BADGE_RELATION_TYPE + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: badge_id_1, + NEPTUNE_RELATIONSHIP_HEADER_TO: start_key, + NEPTUNE_HEADER_LABEL: BadgeMetadata.INVERSE_BADGE_RELATION_TYPE, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: 
ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + + neptune_forward_expected_2 = { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id=start_key, + to_vertex_id=badge_id_2, + label=BadgeMetadata.BADGE_RELATION_TYPE, + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id=start_key, + to_vertex_id=badge_id_2, + label=BadgeMetadata.BADGE_RELATION_TYPE, + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: start_key, + NEPTUNE_RELATIONSHIP_HEADER_TO: badge_id_2, + NEPTUNE_HEADER_LABEL: BadgeMetadata.BADGE_RELATION_TYPE, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + + neptune_reversed_expected_2 = { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id=badge_id_2, + to_vertex_id=start_key, + label=BadgeMetadata.INVERSE_BADGE_RELATION_TYPE, + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id=badge_id_2, + to_vertex_id=start_key, + label=BadgeMetadata.INVERSE_BADGE_RELATION_TYPE, + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: badge_id_2, + NEPTUNE_RELATIONSHIP_HEADER_TO: start_key, + NEPTUNE_HEADER_LABEL: BadgeMetadata.INVERSE_BADGE_RELATION_TYPE, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + expected = [[neptune_forward_expected_1, neptune_reversed_expected_1], + [neptune_forward_expected_2, neptune_reversed_expected_2]] + + self.assertEqual(expected, actual) + + def test_create_records(self) -> None: + expected = [ + { + 'rk': BadgeMetadata.BADGE_KEY_FORMAT.format(badge=badge1.name), + 'category': badge1.category + }, + { + 'column_rk': 'hive://default.base/test/ds', + 'badge_rk': BadgeMetadata.BADGE_KEY_FORMAT.format(badge=badge1.name) + }, + { + 'rk': BadgeMetadata.BADGE_KEY_FORMAT.format(badge=badge2.name), + 'category': badge2.category + }, + { + 'column_rk': 'hive://default.base/test/ds', + 'badge_rk': BadgeMetadata.BADGE_KEY_FORMAT.format(badge=badge2.name) + } + ] + + actual = [] + record = self.badge_metada.create_next_record() + while record: + serialized_record = mysql_serializer.serialize_record(record) + actual.append(serialized_record) + record = self.badge_metada.create_next_record() + + self.assertEqual(expected, actual) diff --git a/databuilder/tests/unit/models/test_dashboard_elasticsearch_document.py b/databuilder/tests/unit/models/test_dashboard_elasticsearch_document.py new file mode 100644 index 0000000000..0d7c03c213 --- /dev/null +++ b/databuilder/tests/unit/models/test_dashboard_elasticsearch_document.py @@ -0,0 +1,55 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import json +import unittest + +from databuilder.models.dashboard_elasticsearch_document import DashboardESDocument + + +class TestDashboardElasticsearchDocument(unittest.TestCase): + + def test_to_json(self) -> None: + """ + Test string generated from to_json method + """ + test_obj = DashboardESDocument(group_name='test_dashboard_group', + name='test_dashboard_name', + description='test_description', + product='mode', + cluster='gold', + group_description='work space group', + query_names=['query1'], + chart_names=['chart1'], + group_url='mode_group_url', + url='mode_report_url', + uri='mode_dashboard://gold.cluster/dashboard_group/dashboard', + last_successful_run_timestamp=10, + total_usage=10, + tags=['test'], + badges=['test_badge']) + + expected_document_dict = {"group_name": "test_dashboard_group", + "name": "test_dashboard_name", + "description": "test_description", + "product": "mode", + "cluster": "gold", + "group_url": "mode_group_url", + "url": "mode_report_url", + "uri": "mode_dashboard://gold.cluster/dashboard_group/dashboard", + "query_names": ['query1'], + "chart_names": ['chart1'], + "last_successful_run_timestamp": 10, + "group_description": "work space group", + "total_usage": 10, + "tags": ["test"], + "badges": ["test_badge"], + + } + + result = test_obj.to_json() + results = result.split("\n") + + # verify two new line characters in result + self.assertEqual(len(results), 2, "Result from to_json() function doesn't have a newline!") + self.assertDictEqual(json.loads(results[0]), expected_document_dict) diff --git a/databuilder/tests/unit/models/test_description_metadata.py b/databuilder/tests/unit/models/test_description_metadata.py new file mode 100644 index 0000000000..28a65823ac --- /dev/null +++ b/databuilder/tests/unit/models/test_description_metadata.py @@ -0,0 +1,124 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from databuilder.models.description_metadata import DescriptionMetadata +from databuilder.serializers import neo4_serializer + + +class TestDescriptionMetadata(unittest.TestCase): + def test_raise_exception_when_missing_data(self) -> None: + # assert raise when missing description node key + self.assertRaises( + Exception, + DescriptionMetadata(text='test_text').next_node + ) + DescriptionMetadata(text='test_text', description_key='test_key').next_node() + DescriptionMetadata(text='test_text', start_key='start_key').next_node() + + # assert raise when missing relation start label + self.assertRaises( + Exception, + DescriptionMetadata(text='test_text', start_key='start_key').next_relation + ) + DescriptionMetadata(text='test_text', start_key='test_key', start_label='Table').next_relation() + + # assert raise when missing relation start key + self.assertRaises( + Exception, + DescriptionMetadata(text='test_text', description_key='test_key', start_label='Table').next_relation + ) + + def test_serialize_table_description_metadata(self) -> None: + description_metadata = DescriptionMetadata( + text='test text 1', + start_label='Table', + start_key='test_start_key' + ) + node_row = description_metadata.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = description_metadata.next_node() + expected = [ + {'description': 'test text 1', 'KEY': 'test_start_key/_description', + 'LABEL': 'Description', 'description_source': 'description'}, + ] + self.assertEqual(actual, expected) + + relation_row = description_metadata.next_relation() + actual = [] + while relation_row: + relation_row_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_row_serialized) + relation_row = description_metadata.next_relation() + expected = [ + {'START_KEY': 'test_start_key', 'START_LABEL': 'Table', 'END_KEY': 'test_start_key/_description', + 'END_LABEL': 'Description', 'TYPE': 'DESCRIPTION', 'REVERSE_TYPE': 'DESCRIPTION_OF'} + ] + self.assertEqual(actual, expected) + + def test_serialize_column_description_metadata(self) -> None: + description_metadata = DescriptionMetadata( + text='test text 2', + start_label='Column', + start_key='test_start_key', + description_key='customized_key' + ) + node_row = description_metadata.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = description_metadata.next_node() + expected = [ + {'description': 'test text 2', 'KEY': 'customized_key', + 'LABEL': 'Description', 'description_source': 'description'}, + ] + self.assertEqual(actual, expected) + + relation_row = description_metadata.next_relation() + actual = [] + while relation_row: + relation_row_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_row_serialized) + relation_row = description_metadata.next_relation() + expected = [ + {'START_KEY': 'test_start_key', 'START_LABEL': 'Column', 'END_KEY': 'customized_key', + 'END_LABEL': 'Description', 'TYPE': 'DESCRIPTION', 'REVERSE_TYPE': 'DESCRIPTION_OF'} + ] + self.assertEqual(actual, expected) + + def test_serialize_column_with_source_description_metadata(self) -> None: + description_metadata = DescriptionMetadata( + text='test text 3', + start_label='Column', + start_key='test_start_key', + description_key='customized_key', + 
source='external', + ) + node_row = description_metadata.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = description_metadata.next_node() + expected = [ + {'description': 'test text 3', 'KEY': 'customized_key', + 'LABEL': 'Programmatic_Description', 'description_source': 'external'}, + ] + self.assertEqual(actual, expected) + + relation_row = description_metadata.next_relation() + actual = [] + while relation_row: + relation_row_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_row_serialized) + relation_row = description_metadata.next_relation() + expected = [ + {'START_KEY': 'test_start_key', 'START_LABEL': 'Column', 'END_KEY': 'customized_key', + 'END_LABEL': 'Programmatic_Description', 'TYPE': 'DESCRIPTION', 'REVERSE_TYPE': 'DESCRIPTION_OF'} + ] + self.assertEqual(actual, expected) diff --git a/databuilder/tests/unit/models/test_es_last_updated.py b/databuilder/tests/unit/models/test_es_last_updated.py new file mode 100644 index 0000000000..6681b677c9 --- /dev/null +++ b/databuilder/tests/unit/models/test_es_last_updated.py @@ -0,0 +1,49 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from databuilder.models.es_last_updated import ESLastUpdated +from databuilder.models.graph_serializable import NODE_KEY, NODE_LABEL +from databuilder.serializers import mysql_serializer, neo4_serializer + + +class TestNeo4jESLastUpdated(unittest.TestCase): + + def setUp(self) -> None: + super(TestNeo4jESLastUpdated, self).setUp() + self.es_last_updated = ESLastUpdated(timestamp=100) + + self.expected_node_results = [{ + NODE_KEY: 'amundsen_updated_timestamp', + NODE_LABEL: 'Updatedtimestamp', + 'latest_timestamp:UNQUOTED': 100, + }] + + def test_create_nodes(self) -> None: + actual = [] + node = self.es_last_updated.create_next_node() + while node: + serialized_node = neo4_serializer.serialize_node(node) + actual.append(serialized_node) + node = self.es_last_updated.create_next_node() + + self.assertEqual(actual, self.expected_node_results) + + def test_create_next_relation(self) -> None: + self.assertIs(self.es_last_updated.create_next_relation(), None) + + def test_create_records(self) -> None: + expected = [{ + 'rk': 'amundsen_updated_timestamp', + 'latest_timestamp': 100 + }] + + actual = [] + record = self.es_last_updated.create_next_record() + while record: + serialized_record = mysql_serializer.serialize_record(record) + actual.append(serialized_record) + record = self.es_last_updated.create_next_record() + + self.assertEqual(expected, actual) diff --git a/databuilder/tests/unit/models/test_fixtures/__init__.py b/databuilder/tests/unit/models/test_fixtures/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/models/test_fixtures/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/models/test_fixtures/table_metadata_fixtures.py b/databuilder/tests/unit/models/test_fixtures/table_metadata_fixtures.py new file mode 100644 index 0000000000..687adf80d8 --- /dev/null +++ b/databuilder/tests/unit/models/test_fixtures/table_metadata_fixtures.py @@ -0,0 +1,1003 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import ANY + +from databuilder.serializers.neptune_serializer import ( + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_HEADER_ID, NEPTUNE_HEADER_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_RELATIONSHIP_HEADER_FROM, + NEPTUNE_RELATIONSHIP_HEADER_TO, +) + +EXPECTED_NEPTUNE_NODES = [ + { + 'name:String(single)': 'test_table1', + NEPTUNE_HEADER_ID: 'Table:hive://gold.test_schema1/test_table1', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'hive://gold.test_schema1/test_table1', + NEPTUNE_HEADER_LABEL: 'Table', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'is_view:Bool(single)': False + }, + { + 'description:String(single)': 'test_table1', + NEPTUNE_HEADER_ID: 'Description:hive://gold.test_schema1/test_table1/_description', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'hive://gold.test_schema1/test_table1/_description', + NEPTUNE_HEADER_LABEL: 'Description', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'description_source:String(single)': 'description' + }, + { + 'sort_order:Long(single)': 0, + 'col_type:String(single)': 'bigint', + 'name:String(single)': 'test_id1', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_HEADER_ID: 'Column:hive://gold.test_schema1/test_table1/test_id1', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'hive://gold.test_schema1/test_table1/test_id1', + NEPTUNE_HEADER_LABEL: 'Column' + }, + { + 'description:String(single)': 'description of test_table1', + NEPTUNE_HEADER_ID: 'Description:hive://gold.test_schema1/test_table1/test_id1/_description', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'hive://gold.test_schema1/test_table1/test_id1/_description', + NEPTUNE_HEADER_LABEL: 'Description', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'description_source:String(single)': 'description' + }, + { + 'sort_order:Long(single)': 1, + 'col_type:String(single)': 'bigint', + 'name:String(single)': 'test_id2', + NEPTUNE_HEADER_ID: 'Column:hive://gold.test_schema1/test_table1/test_id2', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'hive://gold.test_schema1/test_table1/test_id2', + NEPTUNE_HEADER_LABEL: 'Column', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + }, + { + 'description:String(single)': 'description of test_id2', + NEPTUNE_HEADER_ID: 'Description:hive://gold.test_schema1/test_table1/test_id2/_description', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'hive://gold.test_schema1/test_table1/test_id2/_description', + NEPTUNE_HEADER_LABEL: 'Description', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'description_source:String(single)': 
'description' + }, + { + 'sort_order:Long(single)': 2, + 'col_type:String(single)': 'boolean', + 'name:String(single)': 'is_active', + NEPTUNE_HEADER_ID: 'Column:hive://gold.test_schema1/test_table1/is_active', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'hive://gold.test_schema1/test_table1/is_active', + NEPTUNE_HEADER_LABEL: 'Column', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + }, + { + 'sort_order:Long(single)': 3, + 'col_type:String(single)': 'varchar', + 'name:String(single)': 'source', + NEPTUNE_HEADER_ID: 'Column:hive://gold.test_schema1/test_table1/source', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'hive://gold.test_schema1/test_table1/source', + NEPTUNE_HEADER_LABEL: 'Column', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + }, + { + 'description:String(single)': 'description of source', + NEPTUNE_HEADER_ID: 'Description:hive://gold.test_schema1/test_table1/source/_description', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'hive://gold.test_schema1/test_table1/source/_description', + NEPTUNE_HEADER_LABEL: 'Description', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'description_source:String(single)': 'description' + }, + { + 'sort_order:Long(single)': 4, + 'col_type:String(single)': 'timestamp', + 'name:String(single)': 'etl_created_at', + NEPTUNE_HEADER_ID: 'Column:hive://gold.test_schema1/test_table1/etl_created_at', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'hive://gold.test_schema1/test_table1/etl_created_at', + NEPTUNE_HEADER_LABEL: 'Column', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + }, + { + 'description:String(single)': 'description of etl_created_at', + NEPTUNE_HEADER_ID: 'Description:hive://gold.test_schema1/test_table1/etl_created_at/_description', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: + 'hive://gold.test_schema1/test_table1/etl_created_at/_description', + NEPTUNE_HEADER_LABEL: 'Description', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'description_source:String(single)': 'description' + }, + { + 'sort_order:Long(single)': 5, + 'col_type:String(single)': 'varchar', + 'name:String(single)': 'ds', + NEPTUNE_HEADER_ID: 'Column:hive://gold.test_schema1/test_table1/ds', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'hive://gold.test_schema1/test_table1/ds', + NEPTUNE_HEADER_LABEL: 'Column', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + }, + { + 'sort_order:Long(single)': 6, + 'col_type:String(single)': 'array>>', + 'name:String(single)': 'has_nested_type', + NEPTUNE_HEADER_ID: 'Column:hive://gold.test_schema1/test_table1/has_nested_type', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'hive://gold.test_schema1/test_table1/has_nested_type', + NEPTUNE_HEADER_LABEL: 'Column', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + 
NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + }, + { + 'description:String(single)': 'column with nested types', + NEPTUNE_HEADER_ID: 'Description:hive://gold.test_schema1/test_table1/has_nested_type/_description', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: + 'hive://gold.test_schema1/test_table1/has_nested_type/_description', + NEPTUNE_HEADER_LABEL: 'Description', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'description_source:String(single)': 'description' + }, + { + 'kind:String(single)': 'array', + 'name:String(single)': 'has_nested_type', + 'data_type:String(single)': 'array>>', + NEPTUNE_HEADER_ID: 'Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: + 'hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type', + NEPTUNE_HEADER_LABEL: 'Type_Metadata', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + }, + { + 'kind:String(single)': 'array', + 'name:String(single)': '_inner_', + 'data_type:String(single)': 'array>', + NEPTUNE_HEADER_ID: + 'Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type/_inner_', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: + 'hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type/_inner_', + NEPTUNE_HEADER_LABEL: 'Type_Metadata', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + }, + { + 'kind:String(single)': 'array', + 'name:String(single)': '_inner_', + 'data_type:String(single)': 'array', + NEPTUNE_HEADER_ID: + 'Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type/_inner_/_inner_', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: + 'hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type/_inner_/_inner_', + NEPTUNE_HEADER_LABEL: 'Type_Metadata', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + }, + { + 'name:String(single)': 'hive', + NEPTUNE_HEADER_ID: 'Database:database://hive', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'database://hive', + NEPTUNE_HEADER_LABEL: 'Database', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + }, + { + 'name:String(single)': 'gold', + NEPTUNE_HEADER_ID: 'Cluster:hive://gold', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'hive://gold', + NEPTUNE_HEADER_LABEL: 'Cluster', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + }, + { + 'name:String(single)': 'test_schema1', + NEPTUNE_HEADER_ID: 'Schema:hive://gold.test_schema1', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'hive://gold.test_schema1', + NEPTUNE_HEADER_LABEL: 'Schema', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + } +] + 
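+# Readability note on the fixture that follows: each entry in
+# EXPECTED_RELATIONSHIPS_NEPTUNE is a [forward, reverse] pair of expected edge
+# rows, since the Neptune serializer produces one row per direction. In both
+# rows the header id (and the metadata key) follow the
+# "{label}:{from_vertex_id}_{to_vertex_id}" pattern, and the from/to vertex ids
+# are the node label prefixed onto the node key
+# (e.g. 'Table:hive://gold.test_schema1/test_table1'), matching the node ids
+# used in EXPECTED_NEPTUNE_NODES above.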
+EXPECTED_RELATIONSHIPS_NEPTUNE = [ + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Schema:hive://gold.test_schema1', + to_vertex_id='Table:hive://gold.test_schema1/test_table1', + label='TABLE' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Schema:hive://gold.test_schema1', + to_vertex_id='Table:hive://gold.test_schema1/test_table1', + label='TABLE' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Schema:hive://gold.test_schema1', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Table:hive://gold.test_schema1/test_table1', + NEPTUNE_HEADER_LABEL: 'TABLE', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.test_schema1/test_table1', + to_vertex_id='Schema:hive://gold.test_schema1', + label='TABLE_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.test_schema1/test_table1', + to_vertex_id='Schema:hive://gold.test_schema1', + label='TABLE_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Table:hive://gold.test_schema1/test_table1', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Schema:hive://gold.test_schema1', + NEPTUNE_HEADER_LABEL: 'TABLE_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.test_schema1/test_table1', + to_vertex_id='Description:hive://gold.test_schema1/test_table1/_description', + label='DESCRIPTION' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.test_schema1/test_table1', + to_vertex_id='Description:hive://gold.test_schema1/test_table1/_description', + label='DESCRIPTION' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Table:hive://gold.test_schema1/test_table1', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Description:hive://gold.test_schema1/test_table1/_description', + NEPTUNE_HEADER_LABEL: 'DESCRIPTION', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Description:hive://gold.test_schema1/test_table1/_description', + to_vertex_id='Table:hive://gold.test_schema1/test_table1', + label='DESCRIPTION_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Description:hive://gold.test_schema1/test_table1/_description', + to_vertex_id='Table:hive://gold.test_schema1/test_table1', + label='DESCRIPTION_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Description:hive://gold.test_schema1/test_table1/_description', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Table:hive://gold.test_schema1/test_table1', + NEPTUNE_HEADER_LABEL: 'DESCRIPTION_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: 
"{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.test_schema1/test_table1', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/test_id1', + label='COLUMN' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.test_schema1/test_table1', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/test_id1', + label='COLUMN' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Table:hive://gold.test_schema1/test_table1', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Column:hive://gold.test_schema1/test_table1/test_id1', + NEPTUNE_HEADER_LABEL: 'COLUMN', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/test_id1', + to_vertex_id='Table:hive://gold.test_schema1/test_table1', + label='COLUMN_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/test_id1', + to_vertex_id='Table:hive://gold.test_schema1/test_table1', + label='COLUMN_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Column:hive://gold.test_schema1/test_table1/test_id1', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Table:hive://gold.test_schema1/test_table1', + NEPTUNE_HEADER_LABEL: 'COLUMN_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/test_id1', + to_vertex_id='Description:hive://gold.test_schema1/test_table1/test_id1/_description', + label='DESCRIPTION' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/test_id1', + to_vertex_id='Description:hive://gold.test_schema1/test_table1/test_id1/_description', + label='DESCRIPTION' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Column:hive://gold.test_schema1/test_table1/test_id1', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Description:hive://gold.test_schema1/test_table1/test_id1/_description', + NEPTUNE_HEADER_LABEL: 'DESCRIPTION', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Description:hive://gold.test_schema1/test_table1/test_id1/_description', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/test_id1', + label='DESCRIPTION_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Description:hive://gold.test_schema1/test_table1/test_id1/_description', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/test_id1', + label='DESCRIPTION_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Description:hive://gold.test_schema1/test_table1/test_id1/_description', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Column:hive://gold.test_schema1/test_table1/test_id1', + NEPTUNE_HEADER_LABEL: 'DESCRIPTION_OF', + 
NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.test_schema1/test_table1', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/test_id2', + label='COLUMN' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.test_schema1/test_table1', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/test_id2', + label='COLUMN' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Table:hive://gold.test_schema1/test_table1', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Column:hive://gold.test_schema1/test_table1/test_id2', + NEPTUNE_HEADER_LABEL: 'COLUMN', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/test_id2', + to_vertex_id='Table:hive://gold.test_schema1/test_table1', + label='COLUMN_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/test_id2', + to_vertex_id='Table:hive://gold.test_schema1/test_table1', + label='COLUMN_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Column:hive://gold.test_schema1/test_table1/test_id2', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Table:hive://gold.test_schema1/test_table1', + NEPTUNE_HEADER_LABEL: 'COLUMN_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/test_id2', + to_vertex_id='Description:hive://gold.test_schema1/test_table1/test_id2/_description', + label='DESCRIPTION' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/test_id2', + to_vertex_id='Description:hive://gold.test_schema1/test_table1/test_id2/_description', + label='DESCRIPTION' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Column:hive://gold.test_schema1/test_table1/test_id2', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Description:hive://gold.test_schema1/test_table1/test_id2/_description', + NEPTUNE_HEADER_LABEL: 'DESCRIPTION', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Description:hive://gold.test_schema1/test_table1/test_id2/_description', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/test_id2', + label='DESCRIPTION_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Description:hive://gold.test_schema1/test_table1/test_id2/_description', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/test_id2', + label='DESCRIPTION_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 
'Description:hive://gold.test_schema1/test_table1/test_id2/_description', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Column:hive://gold.test_schema1/test_table1/test_id2', + NEPTUNE_HEADER_LABEL: 'DESCRIPTION_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.test_schema1/test_table1', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/is_active', + label='COLUMN' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.test_schema1/test_table1', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/is_active', + label='COLUMN' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Table:hive://gold.test_schema1/test_table1', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Column:hive://gold.test_schema1/test_table1/is_active', + NEPTUNE_HEADER_LABEL: 'COLUMN', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/is_active', + to_vertex_id='Table:hive://gold.test_schema1/test_table1', + label='COLUMN_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/is_active', + to_vertex_id='Table:hive://gold.test_schema1/test_table1', + label='COLUMN_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Column:hive://gold.test_schema1/test_table1/is_active', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Table:hive://gold.test_schema1/test_table1', + NEPTUNE_HEADER_LABEL: 'COLUMN_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.test_schema1/test_table1', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/source', + label='COLUMN' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.test_schema1/test_table1', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/source', + label='COLUMN' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Table:hive://gold.test_schema1/test_table1', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Column:hive://gold.test_schema1/test_table1/source', + NEPTUNE_HEADER_LABEL: 'COLUMN', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/source', + to_vertex_id='Table:hive://gold.test_schema1/test_table1', + label='COLUMN_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/source', + to_vertex_id='Table:hive://gold.test_schema1/test_table1', + label='COLUMN_OF' + ), + 
NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Column:hive://gold.test_schema1/test_table1/source', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Table:hive://gold.test_schema1/test_table1', + NEPTUNE_HEADER_LABEL: 'COLUMN_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/source', + to_vertex_id='Description:hive://gold.test_schema1/test_table1/source/_description', + label='DESCRIPTION' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/source', + to_vertex_id='Description:hive://gold.test_schema1/test_table1/source/_description', + label='DESCRIPTION' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Column:hive://gold.test_schema1/test_table1/source', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Description:hive://gold.test_schema1/test_table1/source/_description', + NEPTUNE_HEADER_LABEL: 'DESCRIPTION', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Description:hive://gold.test_schema1/test_table1/source/_description', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/source', + label='DESCRIPTION_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Description:hive://gold.test_schema1/test_table1/source/_description', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/source', + label='DESCRIPTION_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Description:hive://gold.test_schema1/test_table1/source/_description', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Column:hive://gold.test_schema1/test_table1/source', + NEPTUNE_HEADER_LABEL: 'DESCRIPTION_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.test_schema1/test_table1', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/etl_created_at', + label='COLUMN' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.test_schema1/test_table1', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/etl_created_at', + label='COLUMN' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Table:hive://gold.test_schema1/test_table1', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Column:hive://gold.test_schema1/test_table1/etl_created_at', + NEPTUNE_HEADER_LABEL: 'COLUMN', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/etl_created_at', + to_vertex_id='Table:hive://gold.test_schema1/test_table1', + label='COLUMN_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 
"{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/etl_created_at', + to_vertex_id='Table:hive://gold.test_schema1/test_table1', + label='COLUMN_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Column:hive://gold.test_schema1/test_table1/etl_created_at', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Table:hive://gold.test_schema1/test_table1', + NEPTUNE_HEADER_LABEL: 'COLUMN_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/etl_created_at', + to_vertex_id='Description:hive://gold.test_schema1/test_table1/etl_created_at/_description', + label='DESCRIPTION' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/etl_created_at', + to_vertex_id='Description:hive://gold.test_schema1/test_table1/etl_created_at/_description', + label='DESCRIPTION' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Column:hive://gold.test_schema1/test_table1/etl_created_at', + NEPTUNE_RELATIONSHIP_HEADER_TO: + 'Description:hive://gold.test_schema1/test_table1/etl_created_at/_description', + NEPTUNE_HEADER_LABEL: 'DESCRIPTION', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Description:hive://gold.test_schema1/test_table1/etl_created_at/_description', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/etl_created_at', + label='DESCRIPTION_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Description:hive://gold.test_schema1/test_table1/etl_created_at/_description', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/etl_created_at', + label='DESCRIPTION_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: + 'Description:hive://gold.test_schema1/test_table1/etl_created_at/_description', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Column:hive://gold.test_schema1/test_table1/etl_created_at', + NEPTUNE_HEADER_LABEL: 'DESCRIPTION_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.test_schema1/test_table1', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/ds', + label='COLUMN' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.test_schema1/test_table1', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/ds', + label='COLUMN' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Table:hive://gold.test_schema1/test_table1', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Column:hive://gold.test_schema1/test_table1/ds', + NEPTUNE_HEADER_LABEL: 'COLUMN', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: 
"{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/ds', + to_vertex_id='Table:hive://gold.test_schema1/test_table1', + label='COLUMN_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/ds', + to_vertex_id='Table:hive://gold.test_schema1/test_table1', + label='COLUMN_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Column:hive://gold.test_schema1/test_table1/ds', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Table:hive://gold.test_schema1/test_table1', + NEPTUNE_HEADER_LABEL: 'COLUMN_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.test_schema1/test_table1', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/has_nested_type', + label='COLUMN' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.test_schema1/test_table1', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/has_nested_type', + label='COLUMN' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Table:hive://gold.test_schema1/test_table1', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Column:hive://gold.test_schema1/test_table1/has_nested_type', + NEPTUNE_HEADER_LABEL: 'COLUMN', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/has_nested_type', + to_vertex_id='Table:hive://gold.test_schema1/test_table1', + label='COLUMN_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/has_nested_type', + to_vertex_id='Table:hive://gold.test_schema1/test_table1', + label='COLUMN_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Column:hive://gold.test_schema1/test_table1/has_nested_type', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Table:hive://gold.test_schema1/test_table1', + NEPTUNE_HEADER_LABEL: 'COLUMN_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/has_nested_type', + to_vertex_id='Description:hive://gold.test_schema1/test_table1/has_nested_type/_description', + label='DESCRIPTION' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/has_nested_type', + to_vertex_id='Description:hive://gold.test_schema1/test_table1/has_nested_type/_description', + label='DESCRIPTION' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Column:hive://gold.test_schema1/test_table1/has_nested_type', + NEPTUNE_RELATIONSHIP_HEADER_TO: + 'Description:hive://gold.test_schema1/test_table1/has_nested_type/_description', + NEPTUNE_HEADER_LABEL: 'DESCRIPTION', + 
NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Description:hive://gold.test_schema1/test_table1/has_nested_type/_description', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/has_nested_type', + label='DESCRIPTION_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Description:hive://gold.test_schema1/test_table1/has_nested_type/_description', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/has_nested_type', + label='DESCRIPTION_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: + 'Description:hive://gold.test_schema1/test_table1/has_nested_type/_description', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Column:hive://gold.test_schema1/test_table1/has_nested_type', + NEPTUNE_HEADER_LABEL: 'DESCRIPTION_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/has_nested_type', + to_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type', + label='TYPE_METADATA' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.test_schema1/test_table1/has_nested_type', + to_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type', + label='TYPE_METADATA' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Column:hive://gold.test_schema1/test_table1/has_nested_type', + NEPTUNE_RELATIONSHIP_HEADER_TO: + 'Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type', + NEPTUNE_HEADER_LABEL: 'TYPE_METADATA', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type' + '/type/has_nested_type', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/has_nested_type', + label='TYPE_METADATA_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type' + '/type/has_nested_type', + to_vertex_id='Column:hive://gold.test_schema1/test_table1/has_nested_type', + label='TYPE_METADATA_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: + 'Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Column:hive://gold.test_schema1/test_table1/has_nested_type', + NEPTUNE_HEADER_LABEL: 'TYPE_METADATA_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type' + '/type/has_nested_type', + 
to_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type' + '/type/has_nested_type/_inner_', + label='SUBTYPE' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type' + '/type/has_nested_type', + to_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type' + '/type/has_nested_type/_inner_', + label='SUBTYPE' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: + 'Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type', + NEPTUNE_RELATIONSHIP_HEADER_TO: + 'Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type/_inner_', + NEPTUNE_HEADER_LABEL: 'SUBTYPE', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type' + '/type/has_nested_type/_inner_', + to_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type', + label='SUBTYPE_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type' + '/type/has_nested_type/_inner_', + to_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type', + label='SUBTYPE_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: + 'Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type/_inner_', + NEPTUNE_RELATIONSHIP_HEADER_TO: + 'Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type', + NEPTUNE_HEADER_LABEL: 'SUBTYPE_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type' + '/type/has_nested_type/_inner_', + to_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type' + '/type/has_nested_type/_inner_/_inner_', + label='SUBTYPE' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type' + '/type/has_nested_type/_inner_', + to_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type' + '/type/has_nested_type/_inner_/_inner_', + label='SUBTYPE' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: + 'Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type/_inner_', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type' + '/type/has_nested_type/_inner_/_inner_', + NEPTUNE_HEADER_LABEL: 'SUBTYPE', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type' + 
'/type/has_nested_type/_inner_/_inner_', + to_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type' + '/type/has_nested_type/_inner_', + label='SUBTYPE_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type' + '/type/has_nested_type/_inner_/_inner_', + to_vertex_id='Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type' + '/type/has_nested_type/_inner_', + label='SUBTYPE_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type' + '/type/has_nested_type/_inner_/_inner_', + NEPTUNE_RELATIONSHIP_HEADER_TO: + 'Type_Metadata:hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type/_inner_', + NEPTUNE_HEADER_LABEL: 'SUBTYPE_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Database:database://hive', + NEPTUNE_HEADER_ID: 'CLUSTER:Database:database://hive_Cluster:hive://gold', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'CLUSTER:Database:database://hive_Cluster:hive://gold', + NEPTUNE_HEADER_LABEL: 'CLUSTER', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Cluster:hive://gold' + }, + { + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Cluster:hive://gold', + NEPTUNE_HEADER_ID: 'CLUSTER_OF:Cluster:hive://gold_Database:database://hive', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'CLUSTER_OF:Cluster:hive://gold_Database:database://hive', + NEPTUNE_HEADER_LABEL: 'CLUSTER_OF', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Database:database://hive' + } + ], + [ + { + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Cluster:hive://gold', + NEPTUNE_HEADER_ID: 'SCHEMA:Cluster:hive://gold_Schema:hive://gold.test_schema1', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'SCHEMA:Cluster:hive://gold_Schema:hive://gold.test_schema1', + NEPTUNE_HEADER_LABEL: 'SCHEMA', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Schema:hive://gold.test_schema1' + }, + { + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Schema:hive://gold.test_schema1', + NEPTUNE_HEADER_ID: 'SCHEMA_OF:Schema:hive://gold.test_schema1_Cluster:hive://gold', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: + 'SCHEMA_OF:Schema:hive://gold.test_schema1_Cluster:hive://gold', + NEPTUNE_HEADER_LABEL: 'SCHEMA_OF', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Cluster:hive://gold' + } + ] +] + +EXPECTED_RECORDS_MYSQL = [ + { + 'rk': 'database://hive', + 'name': 'hive' + }, + { + 'rk': 'hive://gold', + 'name': 'gold', + 'database_rk': 'database://hive' + }, + { + 'rk': 'hive://gold.test_schema1', + 'name': 'test_schema1', + 'cluster_rk': 'hive://gold', + }, + { + 'rk': 
'hive://gold.test_schema1/test_table1', + 'name': 'test_table1', + 'is_view': False, + 'schema_rk': 'hive://gold.test_schema1' + }, + { + 'rk': 'hive://gold.test_schema1/test_table1/_description', + 'description': 'test_table1', + 'description_source': 'description', + 'table_rk': 'hive://gold.test_schema1/test_table1' + }, + { + 'rk': 'hive://gold.test_schema1/test_table1/test_id1', + 'name': 'test_id1', + 'type': 'bigint', + 'sort_order': 0, + 'table_rk': 'hive://gold.test_schema1/test_table1' + }, + { + 'rk': 'hive://gold.test_schema1/test_table1/test_id1/_description', + 'description': 'description of test_table1', + 'description_source': 'description', + 'column_rk': 'hive://gold.test_schema1/test_table1/test_id1' + }, + { + 'rk': 'hive://gold.test_schema1/test_table1/test_id2', + 'name': 'test_id2', + 'type': 'bigint', + 'sort_order': 1, + 'table_rk': 'hive://gold.test_schema1/test_table1' + }, + { + 'rk': 'hive://gold.test_schema1/test_table1/test_id2/_description', + 'description': 'description of test_id2', + 'description_source': 'description', + 'column_rk': 'hive://gold.test_schema1/test_table1/test_id2' + }, + { + 'rk': 'hive://gold.test_schema1/test_table1/is_active', + 'name': 'is_active', + 'type': 'boolean', + 'sort_order': 2, + 'table_rk': 'hive://gold.test_schema1/test_table1' + }, + { + 'rk': 'hive://gold.test_schema1/test_table1/source', + 'name': 'source', + 'type': 'varchar', + 'sort_order': 3, + 'table_rk': 'hive://gold.test_schema1/test_table1' + }, + { + 'rk': 'hive://gold.test_schema1/test_table1/source/_description', + 'description': 'description of source', + 'description_source': 'description', + 'column_rk': 'hive://gold.test_schema1/test_table1/source' + }, + { + 'rk': 'hive://gold.test_schema1/test_table1/etl_created_at', + 'name': 'etl_created_at', + 'type': 'timestamp', + 'sort_order': 4, + 'table_rk': 'hive://gold.test_schema1/test_table1' + }, + { + 'rk': 'hive://gold.test_schema1/test_table1/etl_created_at/_description', + 'description': 'description of etl_created_at', + 'description_source': 'description', + 'column_rk': 'hive://gold.test_schema1/test_table1/etl_created_at' + }, + { + 'rk': 'hive://gold.test_schema1/test_table1/ds', + 'name': 'ds', + 'type': 'varchar', + 'sort_order': 5, + 'table_rk': 'hive://gold.test_schema1/test_table1' + }, + { + 'rk': 'hive://gold.test_schema1/test_table1/has_nested_type', + 'name': 'has_nested_type', + 'type': 'array>>', + 'sort_order': 6, + 'table_rk': 'hive://gold.test_schema1/test_table1' + }, + { + 'rk': 'hive://gold.test_schema1/test_table1/has_nested_type/_description', + 'description': 'column with nested types', + 'description_source': 'description', + 'column_rk': 'hive://gold.test_schema1/test_table1/has_nested_type' + } +] diff --git a/databuilder/tests/unit/models/test_graph_serializable.py b/databuilder/tests/unit/models/test_graph_serializable.py new file mode 100644 index 0000000000..3cb010ea2a --- /dev/null +++ b/databuilder/tests/unit/models/test_graph_serializable.py @@ -0,0 +1,163 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import Iterable, Union + +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.serializers import neo4_serializer + + +class TestSerialize(unittest.TestCase): + + def test_serialize(self) -> None: + actors = [Actor('Tom Cruise'), Actor('Meg Ryan')] + cities = [City('San Diego'), City('Oakland')] + movie = Movie('Top Gun', actors, cities) + + actual = [] + node_row = movie.next_node() + while node_row: + actual.append(neo4_serializer.serialize_node(node_row)) + node_row = movie.next_node() + + expected = [ + {'name': 'Top Gun', 'KEY': 'movie://Top Gun', 'LABEL': 'Movie'}, + {'name': 'Top Gun', 'KEY': 'actor://Tom Cruise', 'LABEL': 'Actor'}, + {'name': 'Top Gun', 'KEY': 'actor://Meg Ryan', 'LABEL': 'Actor'}, + {'name': 'Top Gun', 'KEY': 'city://San Diego', 'LABEL': 'City'}, + {'name': 'Top Gun', 'KEY': 'city://Oakland', 'LABEL': 'City'} + ] + self.assertEqual(expected, actual) + + actual = [] + relation_row = movie.next_relation() + while relation_row: + actual.append(neo4_serializer.serialize_relationship(relation_row)) + relation_row = movie.next_relation() + + expected = [ + {'END_KEY': 'actor://Tom Cruise', 'START_LABEL': 'Movie', + 'END_LABEL': 'Actor', 'START_KEY': 'movie://Top Gun', + 'TYPE': 'ACTOR', 'REVERSE_TYPE': 'ACTED_IN'}, + {'END_KEY': 'actor://Meg Ryan', 'START_LABEL': 'Movie', + 'END_LABEL': 'Actor', 'START_KEY': 'movie://Top Gun', + 'TYPE': 'ACTOR', 'REVERSE_TYPE': 'ACTED_IN'}, + {'END_KEY': 'city://San Diego', 'START_LABEL': 'Movie', + 'END_LABEL': 'City', 'START_KEY': 'movie://Top Gun', + 'TYPE': 'FILMED_AT', 'REVERSE_TYPE': 'APPEARS_IN'}, + {'END_KEY': 'city://Oakland', 'START_LABEL': 'Movie', + 'END_LABEL': 'City', 'START_KEY': 'movie://Top Gun', + 'TYPE': 'FILMED_AT', 'REVERSE_TYPE': 'APPEARS_IN'} + ] + self.assertEqual(expected, actual) + + +class Actor(object): + LABEL = 'Actor' + KEY_FORMAT = 'actor://{}' + + def __init__(self, name: str) -> None: + self.name = name + + +class City(object): + LABEL = 'City' + KEY_FORMAT = 'city://{}' + + def __init__(self, name: str) -> None: + self.name = name + + +class Movie(GraphSerializable): + LABEL = 'Movie' + KEY_FORMAT = 'movie://{}' + MOVIE_ACTOR_RELATION_TYPE = 'ACTOR' + ACTOR_MOVIE_RELATION_TYPE = 'ACTED_IN' + MOVIE_CITY_RELATION_TYPE = 'FILMED_AT' + CITY_MOVIE_RELATION_TYPE = 'APPEARS_IN' + + def __init__(self, + name: str, + actors: Iterable[Actor], + cities: Iterable[City]) -> None: + self._name = name + self._actors = actors + self._cities = cities + self._node_iter = iter(self.create_nodes()) + self._relation_iter = iter(self.create_relation()) + + def create_next_node(self) -> Union[GraphNode, None]: + try: + return next(self._node_iter) + except StopIteration: + return None + + def create_next_relation(self) -> Union[GraphRelationship, None]: + try: + return next(self._relation_iter) + except StopIteration: + return None + + def create_nodes(self) -> Iterable[GraphNode]: + result = [GraphNode( + key=Movie.KEY_FORMAT.format(self._name), + label=Movie.LABEL, + attributes={ + 'name': self._name + } + )] + + for actor in self._actors: + actor_node = GraphNode( + key=Actor.KEY_FORMAT.format(actor.name), + label=Actor.LABEL, + attributes={ + 'name': self._name + } + ) + result.append(actor_node) + + for city in self._cities: + city_node = GraphNode( + key=City.KEY_FORMAT.format(city.name), + 
label=City.LABEL, + attributes={ + 'name': self._name + } + ) + result.append(city_node) + return result + + def create_relation(self) -> Iterable[GraphRelationship]: + result = [] + for actor in self._actors: + movie_actor_relation = GraphRelationship( + start_key=Movie.KEY_FORMAT.format(self._name), + end_key=Actor.KEY_FORMAT.format(actor.name), + start_label=Movie.LABEL, + end_label=Actor.LABEL, + type=Movie.MOVIE_ACTOR_RELATION_TYPE, + reverse_type=Movie.ACTOR_MOVIE_RELATION_TYPE, + attributes={} + ) + result.append(movie_actor_relation) + + for city in self._cities: + city_movie_relation = GraphRelationship( + start_key=Movie.KEY_FORMAT.format(self._name), + end_key=City.KEY_FORMAT.format(city.name), + start_label=Movie.LABEL, + end_label=City.LABEL, + type=Movie.MOVIE_CITY_RELATION_TYPE, + reverse_type=Movie.CITY_MOVIE_RELATION_TYPE, + attributes={} + ) + result.append(city_movie_relation) + return result + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/models/test_owner.py b/databuilder/tests/unit/models/test_owner.py new file mode 100644 index 0000000000..c9779d95ce --- /dev/null +++ b/databuilder/tests/unit/models/test_owner.py @@ -0,0 +1,122 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from databuilder.models.owner import Owner +from databuilder.serializers import mysql_serializer, neo4_serializer + + +class TestDashboardOwner(unittest.TestCase): + + def setUp(self) -> None: + self.owner = Owner( + start_label='Table', + start_key='the_key', + owner_emails=[ + ' Foo@bar.biz', # should be converted to 'foo@bar.biz' + 'moo@cow.farm', + ] + ) + + self.expected_nodes = [ + { + 'KEY': 'foo@bar.biz', + 'LABEL': 'User', + 'email': 'foo@bar.biz', + }, + { + 'KEY': 'moo@cow.farm', + 'LABEL': 'User', + 'email': 'moo@cow.farm', + }, + ] + + self.expected_relations = [ + { + 'START_LABEL': 'Table', + 'END_LABEL': 'User', + 'START_KEY': 'the_key', + 'END_KEY': 'foo@bar.biz', + 'TYPE': 'OWNER', + 'REVERSE_TYPE': 'OWNER_OF', + }, + { + 'START_LABEL': 'Table', + 'END_LABEL': 'User', + 'START_KEY': 'the_key', + 'END_KEY': 'moo@cow.farm', + 'TYPE': 'OWNER', + 'REVERSE_TYPE': 'OWNER_OF', + }, + ] + + self.expected_records = [ + { + 'rk': 'foo@bar.biz', + 'email': 'foo@bar.biz' + }, + { + 'table_rk': 'the_key', + 'user_rk': 'foo@bar.biz' + }, + { + 'rk': 'moo@cow.farm', + 'email': 'moo@cow.farm' + }, + { + 'table_rk': 'the_key', + 'user_rk': 'moo@cow.farm' + } + ] + + def test_not_ownable_label(self) -> None: + with self.assertRaises(Exception) as e: + Owner( + start_label='User', # users can't be owned by other users + start_key='user@user.us', + owner_emails=['another_user@user.us'] + ) + self.assertEqual(e.exception.args, ('owners for User are not supported',)) + + def test_owner_nodes(self) -> None: + node = self.owner.next_node() + actual = [] + while node: + node_serialized = neo4_serializer.serialize_node(node) + actual.append(node_serialized) + node = self.owner.next_node() + + self.assertEqual(actual, self.expected_nodes) + + def test_owner_relations(self) -> None: + actual = [] + relation = self.owner.create_next_relation() + while relation: + serialized_relation = neo4_serializer.serialize_relationship(relation) + actual.append(serialized_relation) + relation = self.owner.create_next_relation() + + self.assertEqual(actual, self.expected_relations) + + def test_owner_record(self) -> None: + actual = [] + record = self.owner.create_next_record() + while record: + serialized_record = 
mysql_serializer.serialize_record(record) + actual.append(serialized_record) + record = self.owner.create_next_record() + + self.assertEqual(actual, self.expected_records) + + def test_not_table_serializable(self) -> None: + feature_owner = Owner( + start_label='Feature', + start_key='feature://a/b/c', + owner_emails=['user@user.us'] + ) + with self.assertRaises(Exception) as e: + record = feature_owner.create_next_record() + while record: + record = feature_owner.create_next_record() + self.assertEqual(e.exception.args, ('Feature<>Owner relationship is not table serializable',)) diff --git a/databuilder/tests/unit/models/test_table_column_usage.py b/databuilder/tests/unit/models/test_table_column_usage.py new file mode 100644 index 0000000000..9659d6e280 --- /dev/null +++ b/databuilder/tests/unit/models/test_table_column_usage.py @@ -0,0 +1,127 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import no_type_check +from unittest.mock import ANY + +from databuilder.models.table_column_usage import ColumnReader, TableColumnUsage +from databuilder.serializers import ( + mysql_serializer, neo4_serializer, neptune_serializer, +) +from databuilder.serializers.neptune_serializer import ( + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_HEADER_ID, NEPTUNE_HEADER_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_RELATIONSHIP_HEADER_FROM, + NEPTUNE_RELATIONSHIP_HEADER_TO, +) + + +class TestTableColumnUsage(unittest.TestCase): + + def setUp(self) -> None: + col_readers = [ + ColumnReader( + database='db', + cluster='gold', + schema='scm', + table='foo', + column='*', + user_email='john@example.com' + ) + ] + self.table_col_usage = TableColumnUsage(col_readers=col_readers) + + @no_type_check # mypy is somehow complaining on assignment on expected dict. 
+ def test_serialize(self) -> None: + node_row = self.table_col_usage.next_node() + actual = [] + while node_row: + actual.append(neo4_serializer.serialize_node(node_row)) + node_row = self.table_col_usage.next_node() + + expected = [{'LABEL': 'User', + 'KEY': 'john@example.com', + 'email': 'john@example.com'}] + self.assertEqual(expected, actual) + + rel_row = self.table_col_usage.next_relation() + actual = [] + while rel_row: + actual.append(neo4_serializer.serialize_relationship(rel_row)) + rel_row = self.table_col_usage.next_relation() + + expected = [{'read_count:UNQUOTED': 1, 'END_KEY': 'john@example.com', 'START_LABEL': 'Table', + 'END_LABEL': 'User', 'START_KEY': 'db://gold.scm/foo', 'TYPE': 'READ_BY', 'REVERSE_TYPE': 'READ'}] + self.assertEqual(expected, actual) + + def test_neptune_serialize(self) -> None: + rel_row = self.table_col_usage.next_relation() + actual = [] + while rel_row: + actual.append(neptune_serializer.convert_relationship(rel_row)) + rel_row = self.table_col_usage.next_relation() + expected = [[ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:db://gold.scm/foo', + to_vertex_id='User:john@example.com', + label='READ_BY' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:db://gold.scm/foo', + to_vertex_id='User:john@example.com', + label='READ_BY' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Table:db://gold.scm/foo', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'User:john@example.com', + NEPTUNE_HEADER_LABEL: 'READ_BY', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'read_count:Long(single)': 1 + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='User:john@example.com', + to_vertex_id='Table:db://gold.scm/foo', + label='READ' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='User:john@example.com', + to_vertex_id='Table:db://gold.scm/foo', + label='READ' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'User:john@example.com', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Table:db://gold.scm/foo', + NEPTUNE_HEADER_LABEL: 'READ', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'read_count:Long(single)': 1 + } + ]] + self.maxDiff = None + self.assertListEqual(expected, actual) + + def test_mysql_serialize(self) -> None: + col_readers = [ColumnReader(database='db', cluster='gold', schema='scm', table='foo', column='*', + user_email='john@example.com')] + table_col_usage = TableColumnUsage(col_readers=col_readers) + + actual = [] + record = table_col_usage.next_record() + while record: + actual.append(mysql_serializer.serialize_record(record)) + record = table_col_usage.next_record() + + expected_user = {'rk': 'john@example.com', + 'email': 'john@example.com'} + expected_usage = {'table_rk': 'db://gold.scm/foo', + 'user_rk': 'john@example.com', + 'read_count': 1} + expected = [expected_user, expected_usage] + + self.assertEqual(expected, actual) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/models/test_table_elasticsearch_document.py b/databuilder/tests/unit/models/test_table_elasticsearch_document.py new file mode 100644 index 0000000000..185f1a54e5 --- 
/dev/null +++ b/databuilder/tests/unit/models/test_table_elasticsearch_document.py @@ -0,0 +1,55 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import json +import unittest + +from databuilder.models.table_elasticsearch_document import TableESDocument + + +class TestTableElasticsearchDocument(unittest.TestCase): + + def test_to_json(self) -> None: + """ + Test string generated from to_json method + """ + test_obj = TableESDocument(database='test_database', + cluster='test_cluster', + schema='test_schema', + name='test_table', + key='test_table_key', + last_updated_timestamp=123456789, + description='test_table_description', + column_names=['test_col1', 'test_col2'], + column_descriptions=['test_description1', 'test_description2'], + total_usage=100, + unique_usage=10, + tags=['test'], + programmatic_descriptions=['test'], + badges=['badge1'], + schema_description='schema description') + + expected_document_dict = {"database": "test_database", + "cluster": "test_cluster", + "schema": "test_schema", + "name": "test_table", + "display_name": "test_schema.test_table", + "key": "test_table_key", + "last_updated_timestamp": 123456789, + "description": "test_table_description", + "column_names": ["test_col1", "test_col2"], + "column_descriptions": ["test_description1", "test_description2"], + "total_usage": 100, + "unique_usage": 10, + "tags": ["test"], + "programmatic_descriptions": ['test'], + "badges": ["badge1"], + 'schema_description': 'schema description' + } + + result = test_obj.to_json() + results = result.split("\n") + + # verify two new line characters in result + self.assertEqual(len(results), 2, "Result from to_json() function doesn't have a newline!") + self.assertDictEqual(json.loads(results[0]), expected_document_dict) diff --git a/databuilder/tests/unit/models/test_table_last_updated.py b/databuilder/tests/unit/models/test_table_last_updated.py new file mode 100644 index 0000000000..6763c1f81b --- /dev/null +++ b/databuilder/tests/unit/models/test_table_last_updated.py @@ -0,0 +1,164 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest +from unittest.mock import ANY + +from databuilder.models.graph_serializable import ( + NODE_KEY, NODE_LABEL, RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, + RELATION_START_LABEL, RELATION_TYPE, +) +from databuilder.models.table_last_updated import TableLastUpdated +from databuilder.models.timestamp import timestamp_constants +from databuilder.serializers import ( + mysql_serializer, neo4_serializer, neptune_serializer, +) +from databuilder.serializers.neptune_serializer import ( + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_HEADER_ID, NEPTUNE_HEADER_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_RELATIONSHIP_HEADER_FROM, + NEPTUNE_RELATIONSHIP_HEADER_TO, +) + + +class TestTableLastUpdated(unittest.TestCase): + + def setUp(self) -> None: + super(TestTableLastUpdated, self).setUp() + + self.tableLastUpdated = TableLastUpdated(table_name='test_table', + last_updated_time_epoch=25195665, + schema='default') + + self.expected_node_results = [{ + NODE_KEY: 'hive://gold.default/test_table/timestamp', + NODE_LABEL: 'Timestamp', + 'last_updated_timestamp:UNQUOTED': 25195665, + timestamp_constants.TIMESTAMP_PROPERTY + ":UNQUOTED": 25195665, + 'name': 'last_updated_timestamp' + }] + + self.expected_relation_results = [{ + RELATION_START_KEY: 'hive://gold.default/test_table', + RELATION_START_LABEL: 'Table', + RELATION_END_KEY: 'hive://gold.default/test_table/timestamp', + RELATION_END_LABEL: 'Timestamp', + RELATION_TYPE: 'LAST_UPDATED_AT', + RELATION_REVERSE_TYPE: 'LAST_UPDATED_TIME_OF' + }] + + def test_get_table_model_key(self) -> None: + table = self.tableLastUpdated.get_table_model_key() + self.assertEqual(table, 'hive://gold.default/test_table') + + def test_get_last_updated_model_key(self) -> None: + last_updated = self.tableLastUpdated.get_last_updated_model_key() + self.assertEqual(last_updated, 'hive://gold.default/test_table/timestamp') + + def test_create_nodes(self) -> None: + actual = [] + node = self.tableLastUpdated.create_next_node() + while node: + serialize_node = neo4_serializer.serialize_node(node) + actual.append(serialize_node) + node = self.tableLastUpdated.create_next_node() + + self.assertEqual(actual, self.expected_node_results) + + def test_create_nodes_neptune(self) -> None: + node_id = TableLastUpdated.LAST_UPDATED_NODE_LABEL + ":" + self.tableLastUpdated.get_last_updated_model_key() + expected_nodes = [{ + NEPTUNE_HEADER_ID: node_id, + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: self.tableLastUpdated.get_last_updated_model_key(), + NEPTUNE_HEADER_LABEL: TableLastUpdated.LAST_UPDATED_NODE_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'name:String(single)': 'last_updated_timestamp', + 'last_updated_timestamp:Long(single)': 25195665, + timestamp_constants.TIMESTAMP_PROPERTY + ":Long(single)": 25195665, + }] + + actual = [] + next_node = self.tableLastUpdated.create_next_node() + while next_node: + next_node_serialized = neptune_serializer.convert_node(next_node) + actual.append(next_node_serialized) + next_node = self.tableLastUpdated.create_next_node() + + self.assertEqual(actual, expected_nodes) + + def test_create_relation(self) -> 
None: + actual = [] + relation = self.tableLastUpdated.create_next_relation() + while relation: + serialized_relation = neo4_serializer.serialize_relationship(relation) + actual.append(serialized_relation) + relation = self.tableLastUpdated.create_next_relation() + + def test_create_relation_neptune(self) -> None: + actual = [] + next_relation = self.tableLastUpdated.create_next_relation() + while next_relation: + next_relation_serialized = neptune_serializer.convert_relationship(next_relation) + actual.append(next_relation_serialized) + next_relation = self.tableLastUpdated.create_next_relation() + + expected = [ + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.default/test_table', + to_vertex_id='Timestamp:hive://gold.default/test_table/timestamp', + label='LAST_UPDATED_AT' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:hive://gold.default/test_table', + to_vertex_id='Timestamp:hive://gold.default/test_table/timestamp', + label='LAST_UPDATED_AT' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Table:hive://gold.default/test_table', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Timestamp:hive://gold.default/test_table/timestamp', + NEPTUNE_HEADER_LABEL: 'LAST_UPDATED_AT', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Timestamp:hive://gold.default/test_table/timestamp', + to_vertex_id='Table:hive://gold.default/test_table', + label='LAST_UPDATED_TIME_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Timestamp:hive://gold.default/test_table/timestamp', + to_vertex_id='Table:hive://gold.default/test_table', + label='LAST_UPDATED_TIME_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Timestamp:hive://gold.default/test_table/timestamp', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Table:hive://gold.default/test_table', + NEPTUNE_HEADER_LABEL: 'LAST_UPDATED_TIME_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ] + ] + + self.assertEqual(actual, expected) + + def test_create_records(self) -> None: + expected = [{ + 'rk': 'hive://gold.default/test_table/timestamp', + 'last_updated_timestamp': 25195665, + 'timestamp': 25195665, + 'name': 'last_updated_timestamp', + 'table_rk': 'hive://gold.default/test_table' + }] + + actual = [] + record = self.tableLastUpdated.create_next_record() + while record: + serialized_record = mysql_serializer.serialize_record(record) + actual.append(serialized_record) + record = self.tableLastUpdated.create_next_record() + + self.assertEqual(expected, actual) diff --git a/databuilder/tests/unit/models/test_table_lineage.py b/databuilder/tests/unit/models/test_table_lineage.py new file mode 100644 index 0000000000..007400371b --- /dev/null +++ b/databuilder/tests/unit/models/test_table_lineage.py @@ -0,0 +1,183 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest +from unittest.mock import ANY + +from databuilder.models.graph_serializable import ( + RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, RELATION_START_LABEL, + RELATION_TYPE, +) +from databuilder.models.table_lineage import TableLineage +from databuilder.serializers import ( + mysql_serializer, neo4_serializer, neptune_serializer, +) +from databuilder.serializers.neptune_serializer import ( + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_HEADER_ID, NEPTUNE_HEADER_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_RELATIONSHIP_HEADER_FROM, + NEPTUNE_RELATIONSHIP_HEADER_TO, +) + +DB = 'hive' +SCHEMA = 'base' +TABLE = 'test' +CLUSTER = 'default' + + +class TestTableLineage(unittest.TestCase): + + def setUp(self) -> None: + super(TestTableLineage, self).setUp() + self.table_lineage = TableLineage(table_key=f'{DB}://{CLUSTER}.{SCHEMA}/{TABLE}', + downstream_deps=['hive://default.test_schema/test_table1', + 'hive://default.test_schema/test_table2']) + + self.start_key = f'{DB}://{CLUSTER}.{SCHEMA}/{TABLE}' + self.end_key1 = f'{DB}://{CLUSTER}.test_schema/test_table1' + self.end_key2 = f'{DB}://{CLUSTER}.test_schema/test_table2' + + def test_create_nodes(self) -> None: + actual = [] + node = self.table_lineage.create_next_node() + while node: + serialized_node = neo4_serializer.serialize_node(node) + actual.append(serialized_node) + node = self.table_lineage.create_next_node() + + self.assertEqual(len(actual), 0) + + def test_create_relation(self) -> None: + expected_relations = [ + { + RELATION_START_KEY: self.start_key, + RELATION_START_LABEL: 'Table', + RELATION_END_KEY: self.end_key1, + RELATION_END_LABEL: 'Table', + RELATION_TYPE: TableLineage.ORIGIN_DEPENDENCY_RELATION_TYPE, + RELATION_REVERSE_TYPE: TableLineage.DEPENDENCY_ORIGIN_RELATION_TYPE + }, + { + RELATION_START_KEY: self.start_key, + RELATION_START_LABEL: 'Table', + RELATION_END_KEY: self.end_key2, + RELATION_END_LABEL: 'Table', + RELATION_TYPE: TableLineage.ORIGIN_DEPENDENCY_RELATION_TYPE, + RELATION_REVERSE_TYPE: TableLineage.DEPENDENCY_ORIGIN_RELATION_TYPE + } + ] + + actual = [] + relation = self.table_lineage.create_next_relation() + while relation: + serialized_relation = neo4_serializer.serialize_relationship(relation) + actual.append(serialized_relation) + relation = self.table_lineage.create_next_relation() + + self.assertEqual(actual, expected_relations) + + def test_create_relation_neptune(self) -> None: + expected = [ + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:' + self.start_key, + to_vertex_id='Table:' + self.end_key1, + label=TableLineage.ORIGIN_DEPENDENCY_RELATION_TYPE, + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:' + self.start_key, + to_vertex_id='Table:' + self.end_key1, + label=TableLineage.ORIGIN_DEPENDENCY_RELATION_TYPE, + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Table:' + self.start_key, + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Table:' + self.end_key1, + NEPTUNE_HEADER_LABEL: TableLineage.ORIGIN_DEPENDENCY_RELATION_TYPE, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: 
"{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:' + self.end_key1, + to_vertex_id='Table:' + self.start_key, + label=TableLineage.DEPENDENCY_ORIGIN_RELATION_TYPE + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:' + self.end_key1, + to_vertex_id='Table:' + self.start_key, + label=TableLineage.DEPENDENCY_ORIGIN_RELATION_TYPE + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Table:' + self.end_key1, + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Table:' + self.start_key, + NEPTUNE_HEADER_LABEL: TableLineage.DEPENDENCY_ORIGIN_RELATION_TYPE, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:' + self.start_key, + to_vertex_id='Table:' + self.end_key2, + label=TableLineage.ORIGIN_DEPENDENCY_RELATION_TYPE, + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:' + self.start_key, + to_vertex_id='Table:' + self.end_key2, + label=TableLineage.ORIGIN_DEPENDENCY_RELATION_TYPE, + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Table:' + self.start_key, + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Table:' + self.end_key2, + NEPTUNE_HEADER_LABEL: TableLineage.ORIGIN_DEPENDENCY_RELATION_TYPE, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:' + self.end_key2, + to_vertex_id='Table:' + self.start_key, + label=TableLineage.DEPENDENCY_ORIGIN_RELATION_TYPE + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Table:' + self.end_key2, + to_vertex_id='Table:' + self.start_key, + label=TableLineage.DEPENDENCY_ORIGIN_RELATION_TYPE + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Table:' + self.end_key2, + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Table:' + self.start_key, + NEPTUNE_HEADER_LABEL: TableLineage.DEPENDENCY_ORIGIN_RELATION_TYPE, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ] + ] + + actual = [] + relation = self.table_lineage.create_next_relation() + while relation: + serialized_relation = neptune_serializer.convert_relationship(relation) + actual.append(serialized_relation) + relation = self.table_lineage.create_next_relation() + + self.assertEqual(actual, expected) + + def test_create_records(self) -> None: + expected = [ + { + 'table_source_rk': self.start_key, + 'table_target_rk': self.end_key1 + }, + { + 'table_source_rk': self.start_key, + 'table_target_rk': self.end_key2, + } + ] + + actual = [] + record = self.table_lineage.create_next_record() + while record: + serialized_record = mysql_serializer.serialize_record(record) + actual.append(serialized_record) + record = self.table_lineage.create_next_record() + + self.assertEqual(expected, actual) diff --git a/databuilder/tests/unit/models/test_table_metadata.py b/databuilder/tests/unit/models/test_table_metadata.py new file mode 100644 index 0000000000..030b831b24 --- /dev/null +++ b/databuilder/tests/unit/models/test_table_metadata.py @@ -0,0 +1,426 @@ +# 
Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import copy +import unittest +from typing import Dict, List + +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata +from databuilder.models.type_metadata import ArrayTypeMetadata, TypeMetadata +from databuilder.serializers import ( + mysql_serializer, neo4_serializer, neptune_serializer, +) +from tests.unit.models.test_fixtures.table_metadata_fixtures import ( + EXPECTED_NEPTUNE_NODES, EXPECTED_RECORDS_MYSQL, EXPECTED_RELATIONSHIPS_NEPTUNE, +) + + +class TestTableMetadata(unittest.TestCase): + def setUp(self) -> None: + super(TestTableMetadata, self).setUp() + TableMetadata.serialized_nodes_keys = set() + TableMetadata.serialized_rels_keys = set() + + column_with_type_metadata = ColumnMetadata('has_nested_type', 'column with nested types', + 'array>>', 6) + column_with_type_metadata.set_column_key('hive://gold.test_schema1/test_table1/has_nested_type') + column_with_type_metadata.set_type_metadata(self._set_up_type_metadata(column_with_type_metadata)) + + self.table_metadata = TableMetadata( + 'hive', + 'gold', + 'test_schema1', + 'test_table1', + 'test_table1', + [ + ColumnMetadata('test_id1', 'description of test_table1', 'bigint', 0), + ColumnMetadata('test_id2', 'description of test_id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5), + column_with_type_metadata + ] + ) + + self.table_metadata2 = TableMetadata( + 'hive', + 'gold', + 'test_schema1', + 'test_table1', + 'test_table1', + [ + ColumnMetadata('test_id1', 'description of test_table1', 'bigint', 0), + ColumnMetadata('test_id2', 'description of test_id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5), + column_with_type_metadata + ] + ) + + def _set_up_type_metadata(self, parent_column: ColumnMetadata) -> TypeMetadata: + array_type_metadata = ArrayTypeMetadata( + name='has_nested_type', + parent=parent_column, + type_str='array>>' + ) + nested_array_type_metadata_level1 = ArrayTypeMetadata( + name='_inner_', + parent=array_type_metadata, + type_str='array>' + ) + nested_array_type_metadata_level2 = ArrayTypeMetadata( + name='_inner_', + parent=nested_array_type_metadata_level1, + type_str='array' + ) + + array_type_metadata.array_inner_type = nested_array_type_metadata_level1 + nested_array_type_metadata_level1.array_inner_type = nested_array_type_metadata_level2 + + return array_type_metadata + + def test_serialize(self) -> None: + self.expected_nodes_deduped = [ + {'name': 'test_table1', 'KEY': 'hive://gold.test_schema1/test_table1', 'LABEL': 'Table', + 'is_view:UNQUOTED': False}, + {'description': 'test_table1', 'KEY': 'hive://gold.test_schema1/test_table1/_description', + 'LABEL': 'Description', 'description_source': 'description'}, + {'sort_order:UNQUOTED': 0, 'col_type': 'bigint', 'name': 'test_id1', + 'KEY': 'hive://gold.test_schema1/test_table1/test_id1', 'LABEL': 'Column'}, + {'description': 'description of test_table1', + 'KEY': 'hive://gold.test_schema1/test_table1/test_id1/_description', 'LABEL': 'Description', + 'description_source': 'description'}, + {'sort_order:UNQUOTED': 1, 
'col_type': 'bigint', 'name': 'test_id2', + 'KEY': 'hive://gold.test_schema1/test_table1/test_id2', 'LABEL': 'Column'}, + {'description': 'description of test_id2', + 'KEY': 'hive://gold.test_schema1/test_table1/test_id2/_description', + 'LABEL': 'Description', 'description_source': 'description'}, + {'sort_order:UNQUOTED': 2, 'col_type': 'boolean', 'name': 'is_active', + 'KEY': 'hive://gold.test_schema1/test_table1/is_active', 'LABEL': 'Column'}, + {'sort_order:UNQUOTED': 3, 'col_type': 'varchar', 'name': 'source', + 'KEY': 'hive://gold.test_schema1/test_table1/source', 'LABEL': 'Column'}, + {'description': 'description of source', 'KEY': 'hive://gold.test_schema1/test_table1/source/_description', + 'LABEL': 'Description', 'description_source': 'description'}, + {'sort_order:UNQUOTED': 4, 'col_type': 'timestamp', 'name': 'etl_created_at', + 'KEY': 'hive://gold.test_schema1/test_table1/etl_created_at', 'LABEL': 'Column'}, + {'description': 'description of etl_created_at', + 'KEY': 'hive://gold.test_schema1/test_table1/etl_created_at/_description', 'LABEL': 'Description', + 'description_source': 'description'}, + {'sort_order:UNQUOTED': 5, 'col_type': 'varchar', 'name': 'ds', + 'KEY': 'hive://gold.test_schema1/test_table1/ds', 'LABEL': 'Column'}, + {'sort_order:UNQUOTED': 6, 'col_type': 'array>>', + 'name': 'has_nested_type', 'KEY': 'hive://gold.test_schema1/test_table1/has_nested_type', + 'LABEL': 'Column'}, + {'description': 'column with nested types', + 'KEY': 'hive://gold.test_schema1/test_table1/has_nested_type/_description', 'LABEL': 'Description', + 'description_source': 'description'}, + {'kind': 'array', 'name': 'has_nested_type', 'LABEL': 'Type_Metadata', + 'data_type': 'array>>', + 'KEY': 'hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type'}, + {'kind': 'array', 'name': '_inner_', 'LABEL': 'Type_Metadata', 'data_type': 'array>', + 'KEY': 'hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type/_inner_'}, + {'kind': 'array', 'name': '_inner_', 'LABEL': 'Type_Metadata', 'data_type': 'array', + 'KEY': 'hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type/_inner_/_inner_'} + ] + + self.expected_nodes = copy.deepcopy(self.expected_nodes_deduped) + self.expected_nodes.append({'name': 'hive', 'KEY': 'database://hive', 'LABEL': 'Database'}) + self.expected_nodes.append({'name': 'gold', 'KEY': 'hive://gold', 'LABEL': 'Cluster'}) + self.expected_nodes.append({'name': 'test_schema1', 'KEY': 'hive://gold.test_schema1', 'LABEL': 'Schema'}) + + self.expected_rels_deduped = [ + {'END_KEY': 'hive://gold.test_schema1/test_table1', 'START_LABEL': 'Schema', 'END_LABEL': 'Table', + 'START_KEY': 'hive://gold.test_schema1', 'TYPE': 'TABLE', 'REVERSE_TYPE': 'TABLE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/_description', 'START_LABEL': 'Table', + 'END_LABEL': 'Description', 'START_KEY': 'hive://gold.test_schema1/test_table1', 'TYPE': 'DESCRIPTION', + 'REVERSE_TYPE': 'DESCRIPTION_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/test_id1', 'START_LABEL': 'Table', + 'END_LABEL': 'Column', 'START_KEY': 'hive://gold.test_schema1/test_table1', 'TYPE': 'COLUMN', + 'REVERSE_TYPE': 'COLUMN_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/test_id1/_description', 'START_LABEL': 'Column', + 'END_LABEL': 'Description', 'START_KEY': 'hive://gold.test_schema1/test_table1/test_id1', + 'TYPE': 'DESCRIPTION', 'REVERSE_TYPE': 'DESCRIPTION_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/test_id2', 'START_LABEL': 
'Table', 'END_LABEL': 'Column', + 'START_KEY': 'hive://gold.test_schema1/test_table1', 'TYPE': 'COLUMN', 'REVERSE_TYPE': 'COLUMN_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/test_id2/_description', 'START_LABEL': 'Column', + 'END_LABEL': 'Description', 'START_KEY': 'hive://gold.test_schema1/test_table1/test_id2', + 'TYPE': 'DESCRIPTION', 'REVERSE_TYPE': 'DESCRIPTION_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/is_active', 'START_LABEL': 'Table', 'END_LABEL': 'Column', + 'START_KEY': 'hive://gold.test_schema1/test_table1', 'TYPE': 'COLUMN', 'REVERSE_TYPE': 'COLUMN_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/source', 'START_LABEL': 'Table', 'END_LABEL': 'Column', + 'START_KEY': 'hive://gold.test_schema1/test_table1', 'TYPE': 'COLUMN', 'REVERSE_TYPE': 'COLUMN_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/source/_description', 'START_LABEL': 'Column', + 'END_LABEL': 'Description', 'START_KEY': 'hive://gold.test_schema1/test_table1/source', + 'TYPE': 'DESCRIPTION', + 'REVERSE_TYPE': 'DESCRIPTION_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/etl_created_at', 'START_LABEL': 'Table', + 'END_LABEL': 'Column', 'START_KEY': 'hive://gold.test_schema1/test_table1', 'TYPE': 'COLUMN', + 'REVERSE_TYPE': 'COLUMN_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/etl_created_at/_description', 'START_LABEL': 'Column', + 'END_LABEL': 'Description', 'START_KEY': 'hive://gold.test_schema1/test_table1/etl_created_at', + 'TYPE': 'DESCRIPTION', 'REVERSE_TYPE': 'DESCRIPTION_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/ds', 'START_LABEL': 'Table', 'END_LABEL': 'Column', + 'START_KEY': 'hive://gold.test_schema1/test_table1', 'TYPE': 'COLUMN', 'REVERSE_TYPE': 'COLUMN_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/has_nested_type', 'START_LABEL': 'Table', + 'END_LABEL': 'Column', 'START_KEY': 'hive://gold.test_schema1/test_table1', 'TYPE': 'COLUMN', + 'REVERSE_TYPE': 'COLUMN_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/has_nested_type/_description', 'START_LABEL': 'Column', + 'END_LABEL': 'Description', 'START_KEY': 'hive://gold.test_schema1/test_table1/has_nested_type', + 'TYPE': 'DESCRIPTION', 'REVERSE_TYPE': 'DESCRIPTION_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type', + 'START_KEY': 'hive://gold.test_schema1/test_table1/has_nested_type', 'END_LABEL': 'Type_Metadata', + 'START_LABEL': 'Column', 'TYPE': 'TYPE_METADATA', 'REVERSE_TYPE': 'TYPE_METADATA_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type/_inner_', + 'START_KEY': 'hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', 'TYPE': 'SUBTYPE', + 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type/_inner_/_inner_', + 'START_KEY': 'hive://gold.test_schema1/test_table1/has_nested_type/type/has_nested_type/_inner_', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', 'TYPE': 'SUBTYPE', + 'REVERSE_TYPE': 'SUBTYPE_OF'} + ] + + self.expected_rels = copy.deepcopy(self.expected_rels_deduped) + self.expected_rels.append({'END_KEY': 'hive://gold', 'START_LABEL': 'Database', 'END_LABEL': 'Cluster', + 'START_KEY': 'database://hive', 'TYPE': 'CLUSTER', 'REVERSE_TYPE': 'CLUSTER_OF'}) + self.expected_rels.append({'END_KEY': 'hive://gold.test_schema1', 'START_LABEL': 'Cluster', + 'END_LABEL': 'Schema', 'START_KEY': 
'hive://gold', + 'TYPE': 'SCHEMA', 'REVERSE_TYPE': 'SCHEMA_OF'}) + + node_row = self.table_metadata.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = self.table_metadata.next_node() + for i in range(0, len(self.expected_nodes)): + self.assertEqual(actual[i], self.expected_nodes[i]) + + relation_row = self.table_metadata.next_relation() + actual = [] + while relation_row: + relation_row_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_row_serialized) + relation_row = self.table_metadata.next_relation() + for i in range(0, len(self.expected_rels)): + self.assertEqual(actual[i], self.expected_rels[i]) + + # 2nd record should not show already serialized database, cluster, and schema + node_row = self.table_metadata2.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = self.table_metadata2.next_node() + + self.assertEqual(self.expected_nodes_deduped, actual) + + relation_row = self.table_metadata2.next_relation() + actual = [] + while relation_row: + relation_row_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_row_serialized) + relation_row = self.table_metadata2.next_relation() + + self.assertEqual(self.expected_rels_deduped, actual) + + def test_serialize_neptune(self) -> None: + node_row = self.table_metadata.next_node() + actual = [] + while node_row: + node_row_serialized = neptune_serializer.convert_node(node_row) + actual.append(node_row_serialized) + node_row = self.table_metadata.next_node() + + self.assertEqual(EXPECTED_NEPTUNE_NODES, actual) + + relation_row = self.table_metadata.next_relation() + neptune_actual: List[List[Dict]] = [] + while relation_row: + relation_row_serialized = neptune_serializer.convert_relationship(relation_row) + neptune_actual.append(relation_row_serialized) + relation_row = self.table_metadata.next_relation() + self.maxDiff = None + self.assertEqual(EXPECTED_RELATIONSHIPS_NEPTUNE, neptune_actual) + + def test_serialize_mysql(self) -> None: + actual = [] + record = self.table_metadata.next_record() + while record: + serialized_record = mysql_serializer.serialize_record(record) + actual.append(serialized_record) + record = self.table_metadata.next_record() + + self.assertEqual(EXPECTED_RECORDS_MYSQL, actual) + + def test_table_attributes(self) -> None: + self.table_metadata3 = TableMetadata('hive', 'gold', 'test_schema3', 'test_table3', 'test_table3', [ + ColumnMetadata('test_id1', 'description of test_table1', 'bigint', 0), + ColumnMetadata('test_id2', 'description of test_id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5)], is_view=False, attr1='uri', attr2='attr2') + + node_row = self.table_metadata3.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = self.table_metadata3.next_node() + + self.assertEqual(actual[0].get('attr1'), 'uri') + self.assertEqual(actual[0].get('attr2'), 'attr2') + + # TODO NO test can run before serialiable... 
need to fix + def test_z_custom_sources(self) -> None: + self.custom_source = TableMetadata('hive', 'gold', 'test_schema3', 'test_table4', 'test_table4', [ + ColumnMetadata('test_id1', 'description of test_table1', 'bigint', 0), + ColumnMetadata('test_id2', 'description of test_id2', 'bigint', 1), + ColumnMetadata('is_active', None, 'boolean', 2), + ColumnMetadata('source', 'description of source', 'varchar', 3), + ColumnMetadata('etl_created_at', 'description of etl_created_at', 'timestamp', 4), + ColumnMetadata('ds', None, 'varchar', 5)], is_view=False, description_source="custom") + + node_row = self.custom_source.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = self.custom_source.next_node() + expected = {'LABEL': 'Programmatic_Description', + 'KEY': 'hive://gold.test_schema3/test_table4/_custom_description', + 'description_source': 'custom', 'description': 'test_table4'} + self.assertEqual(actual[1], expected) + + def test_tags_field(self) -> None: + self.table_metadata4 = TableMetadata('hive', 'gold', 'test_schema4', 'test_table4', 'test_table4', [ + ColumnMetadata('test_id1', 'description of test_table1', 'bigint', 0)], + is_view=False, tags=['tag1', 'tag2'], attr1='uri', attr2='attr2') + + node_row = self.table_metadata4.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = self.table_metadata4.next_node() + + self.assertEqual(actual[0].get('attr1'), 'uri') + self.assertEqual(actual[0].get('attr2'), 'attr2') + + self.assertEqual(actual[2].get('LABEL'), 'Tag') + self.assertEqual(actual[2].get('KEY'), 'tag1') + self.assertEqual(actual[3].get('KEY'), 'tag2') + + relation_row = self.table_metadata4.next_relation() + actual = [] + while relation_row: + relation_row_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_row_serialized) + relation_row = self.table_metadata4.next_relation() + + # Table tag relationship + expected_tab_tag_rel1 = {'END_KEY': 'tag1', 'START_LABEL': 'Table', 'END_LABEL': + 'Tag', 'START_KEY': 'hive://gold.test_schema4/test_table4', + 'TYPE': 'TAGGED_BY', 'REVERSE_TYPE': 'TAG'} + expected_tab_tag_rel2 = {'END_KEY': 'tag2', 'START_LABEL': 'Table', + 'END_LABEL': 'Tag', 'START_KEY': 'hive://gold.test_schema4/test_table4', + 'TYPE': 'TAGGED_BY', 'REVERSE_TYPE': 'TAG'} + + self.assertEqual(actual[2], expected_tab_tag_rel1) + self.assertEqual(actual[3], expected_tab_tag_rel2) + + def test_col_badge_field(self) -> None: + self.table_metadata4 = TableMetadata('hive', 'gold', 'test_schema4', 'test_table4', 'test_table4', [ + ColumnMetadata('test_id1', 'description of test_table1', 'bigint', 0, ['col-badge1', 'col-badge2'])], + is_view=False, attr1='uri', attr2='attr2') + + node_row = self.table_metadata4.next_node() + actual = [] + while node_row: + serialized_node_row = neo4_serializer.serialize_node(node_row) + actual.append(serialized_node_row) + node_row = self.table_metadata4.next_node() + + self.assertEqual(actual[4].get('KEY'), 'col-badge1') + self.assertEqual(actual[5].get('KEY'), 'col-badge2') + + relation_row = self.table_metadata4.next_relation() + actual = [] + while relation_row: + serialized_relation_row = neo4_serializer.serialize_relationship(relation_row) + actual.append(serialized_relation_row) + relation_row = self.table_metadata4.next_relation() + + expected_col_badge_rel1 = {'END_KEY': 'col-badge1', 
'START_LABEL': 'Column', + 'END_LABEL': 'Badge', + 'START_KEY': 'hive://gold.test_schema4/test_table4/test_id1', + 'TYPE': 'HAS_BADGE', 'REVERSE_TYPE': 'BADGE_FOR'} + expected_col_badge_rel2 = {'END_KEY': 'col-badge2', 'START_LABEL': 'Column', + 'END_LABEL': 'Badge', + 'START_KEY': 'hive://gold.test_schema4/test_table4/test_id1', + 'TYPE': 'HAS_BADGE', 'REVERSE_TYPE': 'BADGE_FOR'} + + self.assertEqual(actual[4], expected_col_badge_rel1) + self.assertEqual(actual[5], expected_col_badge_rel2) + + def test_tags_populated_from_str(self) -> None: + self.table_metadata5 = TableMetadata('hive', 'gold', 'test_schema5', 'test_table5', 'test_table5', [ + ColumnMetadata('test_id1', 'description of test_table1', 'bigint', 0)], tags="tag3, tag4") + + # Test table tag field populated from str + node_row = self.table_metadata5.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = self.table_metadata5.next_node() + + self.assertEqual(actual[2].get('LABEL'), 'Tag') + self.assertEqual(actual[2].get('KEY'), 'tag3') + self.assertEqual(actual[3].get('KEY'), 'tag4') + + relation_row = self.table_metadata5.next_relation() + actual = [] + while relation_row: + relation_row_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_row_serialized) + relation_row = self.table_metadata5.next_relation() + + # Table tag relationship + expected_tab_tag_rel3 = {'END_KEY': 'tag3', 'START_LABEL': 'Table', 'END_LABEL': + 'Tag', 'START_KEY': 'hive://gold.test_schema5/test_table5', + 'TYPE': 'TAGGED_BY', 'REVERSE_TYPE': 'TAG'} + expected_tab_tag_rel4 = {'END_KEY': 'tag4', 'START_LABEL': 'Table', + 'END_LABEL': 'Tag', 'START_KEY': 'hive://gold.test_schema5/test_table5', + 'TYPE': 'TAGGED_BY', 'REVERSE_TYPE': 'TAG'} + self.assertEqual(actual[2], expected_tab_tag_rel3) + self.assertEqual(actual[3], expected_tab_tag_rel4) + + def test_tags_arent_populated_from_empty_list_and_str(self) -> None: + self.table_metadata6 = TableMetadata('hive', 'gold', 'test_schema6', 'test_table6', 'test_table6', [ + ColumnMetadata('test_id1', 'description of test_table1', 'bigint', 0)], tags=[]) + + self.table_metadata7 = TableMetadata('hive', 'gold', 'test_schema7', 'test_table7', 'test_table7', [ + ColumnMetadata('test_id1', 'description of test_table1', 'bigint', 0)], tags="") + + # Test table tag fields are not populated from empty List + node_row = self.table_metadata6.next_node() + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + self.assertNotEqual(node_row_serialized.get('LABEL'), 'Tag') + node_row = self.table_metadata6.next_node() + + # Test table tag fields are not populated from empty str + node_row = self.table_metadata7.next_node() + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + self.assertNotEqual(node_row_serialized.get('LABEL'), 'Tag') + node_row = self.table_metadata7.next_node() + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/models/test_table_owner.py b/databuilder/tests/unit/models/test_table_owner.py new file mode 100644 index 0000000000..a78b34e2dc --- /dev/null +++ b/databuilder/tests/unit/models/test_table_owner.py @@ -0,0 +1,261 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest +from unittest.mock import ANY + +from databuilder.models.graph_serializable import ( + NODE_KEY, NODE_LABEL, RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, + RELATION_START_LABEL, RELATION_TYPE, +) +from databuilder.models.owner_constants import OWNER_OF_OBJECT_RELATION_TYPE, OWNER_RELATION_TYPE +from databuilder.models.table_owner import TableOwner +from databuilder.models.user import User +from databuilder.serializers import ( + mysql_serializer, neo4_serializer, neptune_serializer, +) +from databuilder.serializers.neptune_serializer import ( + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_HEADER_ID, NEPTUNE_HEADER_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_RELATIONSHIP_HEADER_FROM, + NEPTUNE_RELATIONSHIP_HEADER_TO, +) + +db = 'hive' +SCHEMA = 'BASE' +TABLE = 'TEST' +CLUSTER = 'DEFAULT' +TABLE_KEY = 'hive://DEFAULT.BASE/TEST' +owner1 = 'user1@1' +owner2 = 'user2@2' + + +class TestTableOwner(unittest.TestCase): + + def setUp(self) -> None: + super(TestTableOwner, self).setUp() + self.table_owner = TableOwner(db_name='hive', + schema=SCHEMA, + table_name=TABLE, + cluster=CLUSTER, + owners="user1@1, user2@2 ") + + def test_create_nodes(self) -> None: + expected_node1 = { + NODE_KEY: User.USER_NODE_KEY_FORMAT.format(email=owner1), + NODE_LABEL: User.USER_NODE_LABEL, + User.USER_NODE_EMAIL: owner1 + } + expected_node2 = { + NODE_KEY: User.USER_NODE_KEY_FORMAT.format(email=owner2), + NODE_LABEL: User.USER_NODE_LABEL, + User.USER_NODE_EMAIL: owner2 + } + expected = [expected_node1, expected_node2] + + actual = [] + node = self.table_owner.create_next_node() + while node: + serialized_node = neo4_serializer.serialize_node(node) + actual.append(serialized_node) + node = self.table_owner.create_next_node() + + self.assertEqual(actual, expected) + + def test_create_nodes_neptune(self) -> None: + expected_node1 = { + NEPTUNE_HEADER_ID: "User:" + User.USER_NODE_KEY_FORMAT.format(email=owner1), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: User.USER_NODE_KEY_FORMAT.format(email=owner1), + NEPTUNE_HEADER_LABEL: User.USER_NODE_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + User.USER_NODE_EMAIL + ":String(single)": owner1 + } + expected_node2 = { + NEPTUNE_HEADER_ID: "User:" + User.USER_NODE_KEY_FORMAT.format(email=owner2), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: User.USER_NODE_KEY_FORMAT.format(email=owner2), + NEPTUNE_HEADER_LABEL: User.USER_NODE_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + User.USER_NODE_EMAIL + ":String(single)": owner2 + } + expected = [expected_node1, expected_node2] + + actual = [] + node = self.table_owner.create_next_node() + while node: + serialized_node = neptune_serializer.convert_node(node) + actual.append(serialized_node) + node = self.table_owner.create_next_node() + + self.assertEqual(actual, expected) + + def test_create_relation(self) -> None: + expected_relation1 = { + RELATION_START_KEY: TABLE_KEY, + RELATION_START_LABEL: 'Table', + RELATION_END_KEY: owner1, + RELATION_END_LABEL: 
User.USER_NODE_LABEL, + RELATION_TYPE: OWNER_RELATION_TYPE, + RELATION_REVERSE_TYPE: OWNER_OF_OBJECT_RELATION_TYPE, + } + expected_relation2 = { + RELATION_START_KEY: TABLE_KEY, + RELATION_START_LABEL: 'Table', + RELATION_END_KEY: owner2, + RELATION_END_LABEL: User.USER_NODE_LABEL, + RELATION_TYPE: OWNER_RELATION_TYPE, + RELATION_REVERSE_TYPE: OWNER_OF_OBJECT_RELATION_TYPE, + } + expected = [expected_relation1, expected_relation2] + + actual = [] + relation = self.table_owner.create_next_relation() + while relation: + serialized_relation = neo4_serializer.serialize_relationship(relation) + actual.append(serialized_relation) + relation = self.table_owner.create_next_relation() + + self.assertEqual(actual, expected) + + def test_create_relation_neptune(self) -> None: + expected = [ + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="Table:" + TABLE_KEY, + to_vertex_id="User:" + owner1, + label=OWNER_RELATION_TYPE + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="Table:" + TABLE_KEY, + to_vertex_id="User:" + owner1, + label=OWNER_RELATION_TYPE + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: "Table:" + TABLE_KEY, + NEPTUNE_RELATIONSHIP_HEADER_TO: "User:" + owner1, + NEPTUNE_HEADER_LABEL: OWNER_RELATION_TYPE, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="User:" + owner1, + to_vertex_id="Table:" + TABLE_KEY, + label=OWNER_OF_OBJECT_RELATION_TYPE + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="User:" + owner1, + to_vertex_id="Table:" + TABLE_KEY, + label=OWNER_OF_OBJECT_RELATION_TYPE + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: "User:" + owner1, + NEPTUNE_RELATIONSHIP_HEADER_TO: "Table:" + TABLE_KEY, + NEPTUNE_HEADER_LABEL: OWNER_OF_OBJECT_RELATION_TYPE, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + ], + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="Table:" + TABLE_KEY, + to_vertex_id="User:" + owner2, + label=OWNER_RELATION_TYPE + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="Table:" + TABLE_KEY, + to_vertex_id="User:" + owner2, + label=OWNER_RELATION_TYPE + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: "Table:" + TABLE_KEY, + NEPTUNE_RELATIONSHIP_HEADER_TO: "User:" + owner2, + NEPTUNE_HEADER_LABEL: OWNER_RELATION_TYPE, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="User:" + owner2, + to_vertex_id="Table:" + TABLE_KEY, + label=OWNER_OF_OBJECT_RELATION_TYPE + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="User:" + owner2, + to_vertex_id="Table:" + TABLE_KEY, + label=OWNER_OF_OBJECT_RELATION_TYPE + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: "User:" + owner2, + NEPTUNE_RELATIONSHIP_HEADER_TO: "Table:" + TABLE_KEY, + NEPTUNE_HEADER_LABEL: 
OWNER_OF_OBJECT_RELATION_TYPE, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + ] + ] + + actual = [] + relation = self.table_owner.create_next_relation() + while relation: + serialized_relation = neptune_serializer.convert_relationship(relation) + actual.append(serialized_relation) + relation = self.table_owner.create_next_relation() + + self.assertEqual(expected, actual) + + def test_create_records(self) -> None: + expected = [ + { + 'rk': User.USER_NODE_KEY_FORMAT.format(email=owner1), + 'email': owner1 + }, + { + 'table_rk': TABLE_KEY, + 'user_rk': owner1 + }, + { + 'rk': User.USER_NODE_KEY_FORMAT.format(email=owner2), + 'email': owner2 + }, + { + 'table_rk': TABLE_KEY, + 'user_rk': owner2 + } + ] + + actual = [] + record = self.table_owner.create_next_record() + while record: + serialized_record = mysql_serializer.serialize_record(record) + actual.append(serialized_record) + record = self.table_owner.create_next_record() + + self.assertEqual(actual, expected) + + def test_create_nodes_with_owners_list(self) -> None: + self.table_owner_list = TableOwner(db_name='hive', + schema=SCHEMA, + table_name=TABLE, + cluster=CLUSTER, + owners=['user1@1', ' user2@2 ']) + expected_node1 = { + NODE_KEY: User.USER_NODE_KEY_FORMAT.format(email=owner1), + NODE_LABEL: User.USER_NODE_LABEL, + User.USER_NODE_EMAIL: owner1 + } + expected_node2 = { + NODE_KEY: User.USER_NODE_KEY_FORMAT.format(email=owner2), + NODE_LABEL: User.USER_NODE_LABEL, + User.USER_NODE_EMAIL: owner2 + } + expected = [expected_node1, expected_node2] + + actual = [] + node = self.table_owner_list.create_next_node() + while node: + serialized_node = neo4_serializer.serialize_node(node) + actual.append(serialized_node) + node = self.table_owner_list.create_next_node() + + self.assertEqual(actual, expected) diff --git a/databuilder/tests/unit/models/test_table_serializable.py b/databuilder/tests/unit/models/test_table_serializable.py new file mode 100644 index 0000000000..550df77606 --- /dev/null +++ b/databuilder/tests/unit/models/test_table_serializable.py @@ -0,0 +1,131 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import ( + Iterable, Iterator, Union, +) + +from amundsen_rds.models import RDSModel +from sqlalchemy import ( + BigInteger, Column, ForeignKey, String, +) +from sqlalchemy.ext.declarative import declarative_base + +from databuilder.models.table_serializable import TableSerializable +from databuilder.serializers import mysql_serializer + +Base = declarative_base() + + +class TestTableSerializable(unittest.TestCase): + + def test_table_serializable(self) -> None: + actors = [Actor('Tom Cruise'), Actor('Meg Ryan')] + movie = Movie('Top Gun', actors) + + actual = [] + node_row = movie.next_record() + while node_row: + actual.append(mysql_serializer.serialize_record(node_row)) + node_row = movie.next_record() + + expected = [ + { + 'rk': 'movie://Top Gun', + 'name': 'Top Gun' + }, + { + 'rk': 'actor://Tom Cruise', + 'name': 'Tom Cruise' + }, + { + 'movie_rk': 'movie://Top Gun', + 'actor_rk': 'actor://Tom Cruise' + }, + { + 'rk': 'actor://Meg Ryan', + 'name': 'Meg Ryan' + }, + { + 'movie_rk': 'movie://Top Gun', + 'actor_rk': 'actor://Meg Ryan' + } + ] + + self.assertEqual(expected, actual) + + +class RDSMovie(Base): # type: ignore + __tablename__ = 'movie' + + rk = Column(String(128), primary_key=True) + name = Column(String(128)) + published_tag = Column(String(128), nullable=False) + publisher_last_updated_epoch_ms = Column(BigInteger, nullable=False) + + +class RDSActor(Base): # type: ignore + __tablename__ = 'actor' + + rk = Column(String(128), primary_key=True) + name = Column(String(128)) + published_tag = Column(String(128), nullable=False) + publisher_last_updated_epoch_ms = Column(BigInteger, nullable=False) + + +class RDSMovieActor(Base): # type: ignore + __tablename__ = 'movie_actor' + + movie_rk = Column(String(128), ForeignKey('movie.rk'), primary_key=True) + actor_rk = Column(String(128), ForeignKey('actor.rk'), primary_key=True) + published_tag = Column(String(128), nullable=False) + publisher_last_updated_epoch_ms = Column(BigInteger, nullable=False) + + +class Actor(object): + KEY_FORMAT = 'actor://{}' + + def __init__(self, name: str) -> None: + self.name = name + + +class Movie(TableSerializable): + KEY_FORMAT = 'movie://{}' + + def __init__(self, + name: str, + actors: Iterable[Actor]) -> None: + self._name = name + self._actors = actors + self._record_iter = self._create_record_iterator() + + def create_next_record(self) -> Union[RDSModel, None]: + try: + return next(self._record_iter) + except StopIteration: + return None + + def _create_record_iterator(self) -> Iterator[RDSModel]: + movie_record = RDSMovie( + rk=Movie.KEY_FORMAT.format(self._name), + name=self._name + ) + yield movie_record + + for actor in self._actors: + actor_record = RDSActor( + rk=Actor.KEY_FORMAT.format(actor.name), + name=actor.name + ) + yield actor_record + + movie_actor_record = RDSMovieActor( + movie_rk=Movie.KEY_FORMAT.format(self._name), + actor_rk=Actor.KEY_FORMAT.format(actor.name) + ) + yield movie_actor_record + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/models/test_table_source.py b/databuilder/tests/unit/models/test_table_source.py new file mode 100644 index 0000000000..22fd7ff1e3 --- /dev/null +++ b/databuilder/tests/unit/models/test_table_source.py @@ -0,0 +1,150 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest +from unittest.mock import ANY + +from databuilder.models.graph_serializable import ( + RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, RELATION_START_LABEL, + RELATION_TYPE, +) +from databuilder.models.table_source import TableSource +from databuilder.serializers import ( + mysql_serializer, neo4_serializer, neptune_serializer, +) +from databuilder.serializers.neptune_serializer import ( + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_HEADER_ID, NEPTUNE_HEADER_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_RELATIONSHIP_HEADER_FROM, + NEPTUNE_RELATIONSHIP_HEADER_TO, +) + +DB = 'hive' +SCHEMA = 'base' +TABLE = 'test' +CLUSTER = 'default' +SOURCE = '/etl/sql/file.py' + + +class TestTableSource(unittest.TestCase): + + def setUp(self) -> None: + super(TestTableSource, self).setUp() + self.table_source = TableSource(db_name='hive', + schema=SCHEMA, + table_name=TABLE, + cluster=CLUSTER, + source=SOURCE) + + self.start_key = f'{DB}://{CLUSTER}.{SCHEMA}/{TABLE}/_source' + self.end_key = f'{DB}://{CLUSTER}.{SCHEMA}/{TABLE}' + + def test_get_source_model_key(self) -> None: + source = self.table_source.get_source_model_key() + self.assertEqual(source, f'{DB}://{CLUSTER}.{SCHEMA}/{TABLE}/_source') + + def test_get_metadata_model_key(self) -> None: + metadata = self.table_source.get_metadata_model_key() + self.assertEqual(metadata, 'hive://default.base/test') + + def test_create_nodes(self) -> None: + expected_nodes = [{ + 'LABEL': 'Source', + 'KEY': f'{DB}://{CLUSTER}.{SCHEMA}/{TABLE}/_source', + 'source': SOURCE, + 'source_type': 'github' + }] + + actual = [] + node = self.table_source.create_next_node() + while node: + serialized_node = neo4_serializer.serialize_node(node) + actual.append(serialized_node) + node = self.table_source.create_next_node() + + self.assertEqual(expected_nodes, actual) + + def test_create_relation(self) -> None: + expected_relations = [{ + RELATION_START_KEY: self.start_key, + RELATION_START_LABEL: TableSource.LABEL, + RELATION_END_KEY: self.end_key, + RELATION_END_LABEL: 'Table', + RELATION_TYPE: TableSource.SOURCE_TABLE_RELATION_TYPE, + RELATION_REVERSE_TYPE: TableSource.TABLE_SOURCE_RELATION_TYPE + }] + + actual = [] + relation = self.table_source.create_next_relation() + while relation: + serialized_relation = neo4_serializer.serialize_relationship(relation) + actual.append(serialized_relation) + relation = self.table_source.create_next_relation() + + self.assertEqual(expected_relations, actual) + + def test_create_relation_neptune(self) -> None: + actual = [] + relation = self.table_source.create_next_relation() + while relation: + serialized_relation = neptune_serializer.convert_relationship(relation) + actual.append(serialized_relation) + relation = self.table_source.create_next_relation() + + expected = [ + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="Source:" + self.start_key, + to_vertex_id="Table:" + self.end_key, + label=TableSource.SOURCE_TABLE_RELATION_TYPE + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="Source:" + self.start_key, + to_vertex_id="Table:" + self.end_key, + label=TableSource.SOURCE_TABLE_RELATION_TYPE + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: "Source:" + self.start_key, + 
NEPTUNE_RELATIONSHIP_HEADER_TO: "Table:" + self.end_key, + NEPTUNE_HEADER_LABEL: TableSource.SOURCE_TABLE_RELATION_TYPE, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="Table:" + self.end_key, + to_vertex_id="Source:" + self.start_key, + label=TableSource.TABLE_SOURCE_RELATION_TYPE + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="Table:" + self.end_key, + to_vertex_id="Source:" + self.start_key, + label=TableSource.TABLE_SOURCE_RELATION_TYPE + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: "Table:" + self.end_key, + NEPTUNE_RELATIONSHIP_HEADER_TO: "Source:" + self.start_key, + NEPTUNE_HEADER_LABEL: TableSource.TABLE_SOURCE_RELATION_TYPE, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ] + ] + + self.assertListEqual(expected, actual) + + def test_create_records(self) -> None: + expected = [{ + 'rk': self.table_source.get_source_model_key(), + 'source': self.table_source.source, + 'source_type': self.table_source.source_type, + 'table_rk': self.table_source.get_metadata_model_key() + }] + + actual = [] + record = self.table_source.create_next_record() + while record: + serialized_record = mysql_serializer.serialize_record(record) + actual.append(serialized_record) + record = self.table_source.create_next_record() + + self.assertEqual(expected, actual) diff --git a/databuilder/tests/unit/models/test_table_stats.py b/databuilder/tests/unit/models/test_table_stats.py new file mode 100644 index 0000000000..6eca9c7719 --- /dev/null +++ b/databuilder/tests/unit/models/test_table_stats.py @@ -0,0 +1,178 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest +from unittest.mock import ANY + +from databuilder.models.graph_serializable import ( + NODE_KEY, NODE_LABEL, RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, + RELATION_START_LABEL, RELATION_TYPE, +) +from databuilder.models.table_stats import TableColumnStats +from databuilder.serializers import ( + mysql_serializer, neo4_serializer, neptune_serializer, +) +from databuilder.serializers.neptune_serializer import ( + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_HEADER_ID, NEPTUNE_HEADER_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_RELATIONSHIP_HEADER_FROM, + NEPTUNE_RELATIONSHIP_HEADER_TO, +) + + +class TestTableStats(unittest.TestCase): + + def setUp(self) -> None: + super(TestTableStats, self).setUp() + self.table_stats = TableColumnStats(table_name='base.test', + col_name='col', + stat_name='avg', + stat_val='1', + start_epoch='1', + end_epoch='2',) + + self.expected_node_results = [{ + NODE_KEY: 'hive://gold.base/test/col/avg/', + NODE_LABEL: 'Stat', + 'stat_val': '1', + 'stat_type': 'avg', + 'start_epoch': '1', + 'end_epoch': '2', + }] + + self.expected_relation_results = [{ + RELATION_START_KEY: 'hive://gold.base/test/col/avg/', + RELATION_START_LABEL: 'Stat', + RELATION_END_KEY: 'hive://gold.base/test/col', + RELATION_END_LABEL: 'Column', + RELATION_TYPE: 'STAT_OF', + RELATION_REVERSE_TYPE: 'STAT' + }] + + def test_get_column_stat_model_key(self) -> None: + table_stats = self.table_stats.get_column_stat_model_key() + self.assertEqual(table_stats, 'hive://gold.base/test/col/avg/') + + def test_get_col_key(self) -> None: + metadata = self.table_stats.get_col_key() + self.assertEqual(metadata, 'hive://gold.base/test/col') + + def test_create_nodes(self) -> None: + actual = [] + node = self.table_stats.create_next_node() + while node: + serialized_node = neo4_serializer.serialize_node(node) + actual.append(serialized_node) + node = self.table_stats.create_next_node() + + self.assertEqual(actual, self.expected_node_results) + + def test_create_relation(self) -> None: + actual = [] + relation = self.table_stats.create_next_relation() + while relation: + serialized_relation = neo4_serializer.serialize_relationship(relation) + actual.append(serialized_relation) + relation = self.table_stats.create_next_relation() + + self.assertEqual(actual, self.expected_relation_results) + + def test_create_nodes_neptune(self) -> None: + actual = [] + next_node = self.table_stats.create_next_node() + while next_node: + serialized_node = neptune_serializer.convert_node(next_node) + actual.append(serialized_node) + next_node = self.table_stats.create_next_node() + + expected_neptune_nodes = [{ + NEPTUNE_HEADER_ID: 'Stat:hive://gold.base/test/col/avg/', + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: 'hive://gold.base/test/col/avg/', + NEPTUNE_HEADER_LABEL: 'Stat', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'stat_val:String(single)': '1', + 'stat_type:String(single)': 'avg', + 'start_epoch:String(single)': '1', + 'end_epoch:String(single)': '2', + }] + + self.assertEqual(actual, expected_neptune_nodes) + + def test_create_relation_neptune(self) -> None: + 
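+ # Descriptive note (grounded in the expected rows below): the Neptune serializer emits each
+ # graph relationship as a pair of bulk-loader rows - a forward STAT_OF edge from the Stat node
+ # to its Column and a reverse STAT edge back - each carrying the ID/from/to/label header fields
+ # plus last-extracted-at and creation-type metadata.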
self.expected_relation_result = { + RELATION_START_KEY: 'hive://gold.base/test/col/avg/', + RELATION_START_LABEL: 'Stat', + RELATION_END_KEY: 'hive://gold.base/test/col', + RELATION_END_LABEL: 'Column', + RELATION_TYPE: 'STAT_OF', + RELATION_REVERSE_TYPE: 'STAT' + } + + expected = [ + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Stat:hive://gold.base/test/col/avg/', + to_vertex_id='Column:hive://gold.base/test/col', + label='STAT_OF' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Stat:hive://gold.base/test/col/avg/', + to_vertex_id='Column:hive://gold.base/test/col', + label='STAT_OF' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Stat:hive://gold.base/test/col/avg/', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Column:hive://gold.base/test/col', + NEPTUNE_HEADER_LABEL: 'STAT_OF', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.base/test/col', + to_vertex_id='Stat:hive://gold.base/test/col/avg/', + label='STAT' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id='Column:hive://gold.base/test/col', + to_vertex_id='Stat:hive://gold.base/test/col/avg/', + label='STAT' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: 'Column:hive://gold.base/test/col', + NEPTUNE_RELATIONSHIP_HEADER_TO: 'Stat:hive://gold.base/test/col/avg/', + NEPTUNE_HEADER_LABEL: 'STAT', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ] + ] + + actual = [] + next_relation = self.table_stats.create_next_relation() + while next_relation: + serialized_relation = neptune_serializer.convert_relationship(next_relation) + actual.append(serialized_relation) + next_relation = self.table_stats.create_next_relation() + + self.assertListEqual(actual, expected) + + def test_create_records(self) -> None: + expected = [{ + 'rk': 'hive://gold.base/test/col/avg/', + 'stat_val': '1', + 'stat_type': 'avg', + 'start_epoch': '1', + 'end_epoch': '2', + 'column_rk': 'hive://gold.base/test/col' + }] + + actual = [] + record = self.table_stats.create_next_record() + while record: + serialized_record = mysql_serializer.serialize_record(record) + actual.append(serialized_record) + record = self.table_stats.create_next_record() + + self.assertEqual(actual, expected) diff --git a/databuilder/tests/unit/models/test_type_metadata.py b/databuilder/tests/unit/models/test_type_metadata.py new file mode 100644 index 0000000000..dd5c39c836 --- /dev/null +++ b/databuilder/tests/unit/models/test_type_metadata.py @@ -0,0 +1,813 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from databuilder.models.table_metadata import ColumnMetadata +from databuilder.models.type_metadata import ( + ArrayTypeMetadata, MapTypeMetadata, ScalarTypeMetadata, StructTypeMetadata, +) +from databuilder.serializers import neo4_serializer + + +class TestTypeMetadata(unittest.TestCase): + def setUp(self) -> None: + self.column_key = 'hive://gold.test_schema1/test_table1/col1' + + def test_serialize_array_type_metadata(self) -> None: + column = ColumnMetadata('col1', None, 'array>>', 0) + column.set_column_key(self.column_key) + + array_type_metadata = ArrayTypeMetadata( + name='col1', + parent=column, + type_str='array>>' + ) + nested_array_type_metadata_level1 = ArrayTypeMetadata( + name='_inner_', + parent=array_type_metadata, + type_str='array>' + ) + nested_array_type_metadata_level2 = ArrayTypeMetadata( + name='_inner_', + parent=nested_array_type_metadata_level1, + type_str='array' + ) + nested_scalar_type_metadata_level3 = ScalarTypeMetadata( + name='_inner_', + parent=nested_array_type_metadata_level2, + type_str='string' + ) + + array_type_metadata.array_inner_type = nested_array_type_metadata_level1 + nested_array_type_metadata_level1.array_inner_type = nested_array_type_metadata_level2 + nested_array_type_metadata_level2.array_inner_type = nested_scalar_type_metadata_level3 + + expected_nodes = [ + {'kind': 'array', 'name': 'col1', 'LABEL': 'Type_Metadata', 'data_type': 'array>>', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1'}, + {'kind': 'array', 'name': '_inner_', 'LABEL': 'Type_Metadata', 'data_type': 'array>', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_'}, + {'kind': 'array', 'name': '_inner_', 'LABEL': 'Type_Metadata', 'data_type': 'array', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_/_inner_'} + ] + expected_rels = [ + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Column', + 'TYPE': 'TYPE_METADATA', 'REVERSE_TYPE': 'TYPE_METADATA_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', + 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_/_inner_', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', + 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'} + ] + + node_row = array_type_metadata.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = array_type_metadata.next_node() + for i in range(0, len(expected_nodes)): + self.assertEqual(actual[i], expected_nodes[i]) + + relation_row = array_type_metadata.next_relation() + actual = [] + while relation_row: + relation_row_serialized = neo4_serializer.serialize_relationship( + relation_row + ) + actual.append(relation_row_serialized) + relation_row = array_type_metadata.next_relation() + for i in range(0, len(expected_rels)): + self.assertEqual(actual[i], expected_rels[i]) + + def test_serialize_array_map_type_metadata(self) -> None: + column = ColumnMetadata('col1', None, 'array>>', 0) + column.set_column_key(self.column_key) + + array_type_metadata = 
ArrayTypeMetadata( + name='col1', + parent=column, + type_str='array>>' + ) + nested_map_type_metadata_level1 = MapTypeMetadata( + name='_inner_', + parent=array_type_metadata, + type_str='map>' + ) + nested_map_key = ScalarTypeMetadata( + name='_map_key', + parent=nested_map_type_metadata_level1, + type_str='string' + ) + nested_array_type_metadata_level2 = ArrayTypeMetadata( + name='_map_value', + parent=nested_map_type_metadata_level1, + type_str='array' + ) + nested_scalar_type_metadata_level3 = ScalarTypeMetadata( + name='_inner_', + parent=nested_array_type_metadata_level2, + type_str='string' + ) + + array_type_metadata.array_inner_type = nested_map_type_metadata_level1 + nested_map_type_metadata_level1.map_key_type = nested_map_key + nested_map_type_metadata_level1.map_value_type = nested_array_type_metadata_level2 + nested_array_type_metadata_level2.array_inner_type = nested_scalar_type_metadata_level3 + + expected_nodes = [ + {'kind': 'array', 'data_type': 'array>>', + 'LABEL': 'Type_Metadata', 'name': 'col1', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1'}, + {'kind': 'map', 'data_type': 'map>', 'LABEL': 'Type_Metadata', 'name': '_inner_', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_'}, + {'kind': 'scalar', 'data_type': 'string', 'LABEL': 'Type_Metadata', 'name': '_map_key', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_/_map_key'}, + {'kind': 'array', 'data_type': 'array', 'LABEL': 'Type_Metadata', 'name': '_map_value', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_/_map_value'} + ] + expected_rels = [ + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Column', + 'TYPE': 'TYPE_METADATA', 'REVERSE_TYPE': 'TYPE_METADATA_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', + 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_/_map_key', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', + 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_/_map_value', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', + 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'} + ] + + node_row = array_type_metadata.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = array_type_metadata.next_node() + for i in range(0, len(expected_nodes)): + self.assertEqual(actual[i], expected_nodes[i]) + + relation_row = array_type_metadata.next_relation() + actual = [] + while relation_row: + relation_row_serialized = neo4_serializer.serialize_relationship( + relation_row + ) + actual.append(relation_row_serialized) + relation_row = array_type_metadata.next_relation() + for i in range(0, len(expected_rels)): + self.assertEqual(actual[i], expected_rels[i]) + + def test_serialize_array_struct_type_metadata(self) -> None: + column = ColumnMetadata('col1', None, 'array,c2:string>>', 0) + column.set_column_key(self.column_key) + 
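+ # Descriptive note: the TypeMetadata tree below is assembled by hand to mirror the column's
+ # nested type - an array whose inner element is a struct holding an array field (c1) and a
+ # scalar field (c2); each node's parent reference determines the hierarchical KEY expected
+ # in the serialized output.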
+ array_type_metadata = ArrayTypeMetadata( + name='col1', + parent=column, + type_str='array,c2:string>>' + ) + nested_struct_type_metadata_level1 = StructTypeMetadata( + name='_inner_', + parent=array_type_metadata, + type_str='struct,c2:string>' + ) + nested_array_type_metadata_level2 = ArrayTypeMetadata( + name='c1', + parent=nested_struct_type_metadata_level1, + type_str='array' + ) + nested_scalar_type_metadata_level2 = ScalarTypeMetadata( + name='c2', + parent=nested_struct_type_metadata_level1, + type_str='string' + ) + + array_type_metadata.array_inner_type = nested_struct_type_metadata_level1 + nested_struct_type_metadata_level1.struct_items = {'c1': nested_array_type_metadata_level2, + 'c2': nested_scalar_type_metadata_level2} + nested_array_type_metadata_level2.sort_order = 0 + nested_scalar_type_metadata_level2.sort_order = 1 + + expected_nodes = [ + {'kind': 'array', 'name': 'col1', 'data_type': 'array,c2:string>>', + 'LABEL': 'Type_Metadata', 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1'}, + {'kind': 'struct', 'name': '_inner_', 'data_type': 'struct,c2:string>', + 'LABEL': 'Type_Metadata', 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_'}, + {'kind': 'array', 'name': 'c1', 'data_type': 'array', + 'LABEL': 'Type_Metadata', 'sort_order:UNQUOTED': 0, + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_/c1'}, + {'kind': 'scalar', 'name': 'c2', 'data_type': 'string', + 'LABEL': 'Type_Metadata', 'sort_order:UNQUOTED': 1, + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_/c2'}, + ] + expected_rels = [ + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Column', + 'TYPE': 'TYPE_METADATA', 'REVERSE_TYPE': 'TYPE_METADATA_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', + 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_/c1', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', + 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_/c2', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', + 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'} + ] + + node_row = array_type_metadata.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = array_type_metadata.next_node() + for i in range(0, len(expected_nodes)): + self.assertEqual(actual[i], expected_nodes[i]) + + relation_row = array_type_metadata.next_relation() + actual = [] + while relation_row: + relation_row_serialized = neo4_serializer.serialize_relationship( + relation_row + ) + actual.append(relation_row_serialized) + relation_row = array_type_metadata.next_relation() + for i in range(0, len(expected_rels)): + self.assertEqual(actual[i], expected_rels[i]) + + def test_serialize_map_type_metadata(self) -> None: + column = ColumnMetadata('col1', None, 'map>>', 0) + column.set_column_key(self.column_key) + + map_type_metadata = MapTypeMetadata( + name='col1', + 
parent=column, + type_str='map>>' + ) + map_key = ScalarTypeMetadata( + name='_map_key', + parent=map_type_metadata, + type_str='string' + ) + nested_map_type_metadata_level1 = MapTypeMetadata( + name='_map_value', + parent=map_type_metadata, + type_str='map>' + ) + nested_map_key_level1 = ScalarTypeMetadata( + name='_map_key', + parent=nested_map_type_metadata_level1, + type_str='string' + ) + nested_map_type_metadata_level2 = MapTypeMetadata( + name='_map_value', + parent=nested_map_type_metadata_level1, + type_str='map' + ) + nested_map_key_level2 = ScalarTypeMetadata( + name='_map_key', + parent=nested_map_type_metadata_level2, + type_str='string' + ) + nested_scalar_type_metadata_level3 = ScalarTypeMetadata( + name='_map_value', + parent=nested_map_type_metadata_level2, + type_str='string' + ) + + map_type_metadata.map_key_type = map_key + map_type_metadata.map_value_type = nested_map_type_metadata_level1 + nested_map_type_metadata_level1.map_key_type = nested_map_key_level1 + nested_map_type_metadata_level1.map_value_type = nested_map_type_metadata_level2 + nested_map_type_metadata_level2.map_key_type = nested_map_key_level2 + nested_map_type_metadata_level2.map_value_type = nested_scalar_type_metadata_level3 + + expected_nodes = [ + {'kind': 'map', 'name': 'col1', 'data_type': 'map>>', + 'LABEL': 'Type_Metadata', 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1'}, + {'kind': 'scalar', 'name': '_map_key', 'data_type': 'string', 'LABEL': 'Type_Metadata', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_key'}, + {'kind': 'map', 'name': '_map_value', 'data_type': 'map>', + 'LABEL': 'Type_Metadata', 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value'}, + {'kind': 'scalar', 'name': '_map_key', 'data_type': 'string', 'LABEL': 'Type_Metadata', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value/_map_key'}, + {'kind': 'map', 'name': '_map_value', 'data_type': 'map', 'LABEL': 'Type_Metadata', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value/_map_value'}, + {'kind': 'scalar', 'name': '_map_key', 'data_type': 'string', 'LABEL': 'Type_Metadata', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value/_map_value/_map_key'}, + {'kind': 'scalar', 'name': '_map_value', 'data_type': 'string', 'LABEL': 'Type_Metadata', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value/_map_value/_map_value'} + ] + expected_rels = [ + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1', 'END_LABEL': 'Type_Metadata', + 'START_LABEL': 'Column', 'TYPE': 'TYPE_METADATA', 'REVERSE_TYPE': 'TYPE_METADATA_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_key', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', 'END_LABEL': 'Type_Metadata', + 'START_LABEL': 'Type_Metadata', 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', 'END_LABEL': 'Type_Metadata', + 'START_LABEL': 'Type_Metadata', 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value/_map_key', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', 'TYPE': 'SUBTYPE', + 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 
'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value/_map_value', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', 'TYPE': 'SUBTYPE', + 'REVERSE_TYPE': 'SUBTYPE_OF'} + ] + + node_row = map_type_metadata.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = map_type_metadata.next_node() + for i in range(0, len(expected_nodes)): + self.assertEqual(actual[i], expected_nodes[i]) + + relation_row = map_type_metadata.next_relation() + actual = [] + while relation_row: + relation_row_serialized = neo4_serializer.serialize_relationship( + relation_row + ) + actual.append(relation_row_serialized) + relation_row = map_type_metadata.next_relation() + for i in range(0, len(expected_rels)): + self.assertEqual(actual[i], expected_rels[i]) + + def test_serialize_map_struct_type_metadata(self) -> None: + column = ColumnMetadata('col1', None, 'map,c2:string>>', 0) + column.set_column_key(self.column_key) + + map_type_metadata = MapTypeMetadata( + name='col1', + parent=column, + type_str='map,c2:string>>' + ) + map_key = ScalarTypeMetadata( + name='_map_key', + parent=map_type_metadata, + type_str='string' + ) + nested_struct_type_metadata_level1 = StructTypeMetadata( + name='_map_value', + parent=map_type_metadata, + type_str='struct,c2:string>' + ) + nested_map_type_metadata_level2 = MapTypeMetadata( + name='c1', + parent=nested_struct_type_metadata_level1, + type_str='map' + ) + nested_map_key = ScalarTypeMetadata( + name='_map_key', + parent=nested_map_type_metadata_level2, + type_str='string' + ) + nested_scalar_type_metadata_level3 = ScalarTypeMetadata( + name='_map_value', + parent=nested_map_type_metadata_level2, + type_str='string' + ) + nested_scalar_type_metadata_level2 = ScalarTypeMetadata( + name='c2', + parent=nested_struct_type_metadata_level1, + type_str='string' + ) + + map_type_metadata.map_key_type = map_key + map_type_metadata.map_value_type = nested_struct_type_metadata_level1 + nested_struct_type_metadata_level1.struct_items = {'c1': nested_map_type_metadata_level2, + 'c2': nested_scalar_type_metadata_level2} + nested_map_type_metadata_level2.map_key_type = nested_map_key + nested_map_type_metadata_level2.map_value_type = nested_scalar_type_metadata_level3 + nested_map_type_metadata_level2.sort_order = 0 + nested_scalar_type_metadata_level2.sort_order = 1 + + expected_nodes = [ + {'kind': 'map', 'name': 'col1', 'data_type': 'map,c2:string>>', + 'LABEL': 'Type_Metadata', 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1'}, + {'kind': 'scalar', 'name': '_map_key', 'data_type': 'string', 'LABEL': 'Type_Metadata', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_key'}, + {'kind': 'struct', 'name': '_map_value', 'data_type': 'struct,c2:string>', + 'LABEL': 'Type_Metadata', 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value'}, + {'kind': 'map', 'name': 'c1', 'data_type': 'map', 'sort_order:UNQUOTED': 0, + 'LABEL': 'Type_Metadata', 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value/c1'}, + {'kind': 'scalar', 'name': '_map_key', 'data_type': 'string', 'LABEL': 'Type_Metadata', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value/c1/_map_key'}, + {'kind': 'scalar', 'name': '_map_value', 'data_type': 'string', 'LABEL': 'Type_Metadata', + 'KEY': 
'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value/c1/_map_value'}, + {'kind': 'scalar', 'name': 'c2', 'data_type': 'string', 'LABEL': 'Type_Metadata', 'sort_order:UNQUOTED': 1, + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value/c2'} + ] + expected_rels = [ + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1', 'END_LABEL': 'Type_Metadata', + 'START_LABEL': 'Column', 'TYPE': 'TYPE_METADATA', 'REVERSE_TYPE': 'TYPE_METADATA_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_key', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', 'END_LABEL': 'Type_Metadata', + 'START_LABEL': 'Type_Metadata', 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', 'END_LABEL': 'Type_Metadata', + 'START_LABEL': 'Type_Metadata', 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value/c1', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', 'TYPE': 'SUBTYPE', + 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value/c1/_map_key', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value/c1', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', 'TYPE': 'SUBTYPE', + 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value/c1/_map_value', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value/c1', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', 'TYPE': 'SUBTYPE', + 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value/c2', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_map_value', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', 'TYPE': 'SUBTYPE', + 'REVERSE_TYPE': 'SUBTYPE_OF'} + ] + + node_row = map_type_metadata.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = map_type_metadata.next_node() + for i in range(0, len(expected_nodes)): + self.assertEqual(actual[i], expected_nodes[i]) + + relation_row = map_type_metadata.next_relation() + actual = [] + while relation_row: + relation_row_serialized = neo4_serializer.serialize_relationship( + relation_row + ) + actual.append(relation_row_serialized) + relation_row = map_type_metadata.next_relation() + for i in range(0, len(expected_rels)): + self.assertEqual(actual[i], expected_rels[i]) + + def test_serialize_struct_type_metadata(self) -> None: + column = ColumnMetadata('col1', None, 'struct>,c5:string>', 0) + column.set_column_key(self.column_key) + + struct_type_metadata = StructTypeMetadata( + name='col1', + parent=column, + type_str='struct>,c5:string>' + ) + nested_struct_type_metadata_level1 = StructTypeMetadata( + name='c1', + parent=struct_type_metadata, + type_str='struct>' + ) + nested_struct_type_metadata_level2 = StructTypeMetadata( + name='c2', + parent=nested_struct_type_metadata_level1, + type_str='struct' + ) + nested_scalar_type_metadata_c3 = ScalarTypeMetadata( + name='c3', + parent=nested_struct_type_metadata_level2, + 
type_str='string', + ) + nested_scalar_type_metadata_c4 = ScalarTypeMetadata( + name='c4', + parent=nested_struct_type_metadata_level2, + type_str='string' + ) + nested_scalar_type_metadata_c5 = ScalarTypeMetadata( + name='c5', + parent=struct_type_metadata, + type_str='string', + ) + + struct_type_metadata.struct_items = {'c1': nested_struct_type_metadata_level1, + 'c5': nested_scalar_type_metadata_c5} + nested_struct_type_metadata_level1.struct_items = {'c2': nested_struct_type_metadata_level2} + nested_struct_type_metadata_level2.struct_items = {'c3': nested_scalar_type_metadata_c3, + 'c4': nested_scalar_type_metadata_c4} + nested_struct_type_metadata_level1.sort_order = 0 + nested_scalar_type_metadata_c5.sort_order = 1 + nested_struct_type_metadata_level2.sort_order = 0 + nested_scalar_type_metadata_c3.sort_order = 0 + nested_scalar_type_metadata_c4.sort_order = 1 + + nested_scalar_type_metadata_c3.set_description('description of c3') + nested_scalar_type_metadata_c3.set_badges(['badge1']) + nested_scalar_type_metadata_c5.set_description('description of c5') + nested_scalar_type_metadata_c5.set_badges(['badge1', 'badge2']) + + expected_nodes = [ + {'kind': 'struct', 'name': 'col1', + 'data_type': 'struct>,c5:string>', + 'LABEL': 'Type_Metadata', 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1'}, + {'kind': 'struct', 'name': 'c1', 'data_type': 'struct>', + 'LABEL': 'Type_Metadata', 'sort_order:UNQUOTED': 0, + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1'}, + {'kind': 'struct', 'name': 'c2', 'data_type': 'struct', + 'LABEL': 'Type_Metadata', 'sort_order:UNQUOTED': 0, + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1/c2'}, + {'kind': 'scalar', 'name': 'c3', 'data_type': 'string', + 'LABEL': 'Type_Metadata', 'sort_order:UNQUOTED': 0, + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1/c2/c3'}, + {'description': 'description of c3', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1/c2/c3/_description', + 'LABEL': 'Description', 'description_source': 'description'}, + {'KEY': 'badge1', 'LABEL': 'Badge', 'category': 'type_metadata'}, + {'kind': 'scalar', 'name': 'c4', 'data_type': 'string', + 'LABEL': 'Type_Metadata', 'sort_order:UNQUOTED': 1, + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1/c2/c4'}, + {'kind': 'scalar', 'name': 'c5', 'data_type': 'string', + 'LABEL': 'Type_Metadata', 'sort_order:UNQUOTED': 1, + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c5'}, + {'description': 'description of c5', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c5/_description', + 'LABEL': 'Description', 'description_source': 'description'}, + {'KEY': 'badge1', 'LABEL': 'Badge', 'category': 'type_metadata'}, + {'KEY': 'badge2', 'LABEL': 'Badge', 'category': 'type_metadata'} + ] + expected_rels = [ + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Column', + 'TYPE': 'TYPE_METADATA', 'REVERSE_TYPE': 'TYPE_METADATA_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', + 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1/c2', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 
'Type_Metadata', + 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1/c2/c3', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1/c2', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', + 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1/c2/c3/_description', + 'START_LABEL': 'Type_Metadata', 'END_LABEL': 'Description', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1/c2/c3', + 'TYPE': 'DESCRIPTION', 'REVERSE_TYPE': 'DESCRIPTION_OF'}, + {'END_KEY': 'badge1', 'START_LABEL': 'Type_Metadata', 'END_LABEL': 'Badge', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1/c2/c3', + 'TYPE': 'HAS_BADGE', 'REVERSE_TYPE': 'BADGE_FOR'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1/c2/c4', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1/c2', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', + 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c5', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', + 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c5/_description', + 'START_LABEL': 'Type_Metadata', 'END_LABEL': 'Description', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c5', + 'TYPE': 'DESCRIPTION', 'REVERSE_TYPE': 'DESCRIPTION_OF'}, + {'END_KEY': 'badge1', 'START_LABEL': 'Type_Metadata', 'END_LABEL': 'Badge', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c5', + 'TYPE': 'HAS_BADGE', 'REVERSE_TYPE': 'BADGE_FOR'}, + {'END_KEY': 'badge2', 'START_LABEL': 'Type_Metadata', 'END_LABEL': 'Badge', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c5', + 'TYPE': 'HAS_BADGE', 'REVERSE_TYPE': 'BADGE_FOR'} + ] + + node_row = struct_type_metadata.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = struct_type_metadata.next_node() + for i in range(0, len(expected_nodes)): + self.assertEqual(actual[i], expected_nodes[i]) + + relation_row = struct_type_metadata.next_relation() + actual = [] + while relation_row: + relation_row_serialized = neo4_serializer.serialize_relationship( + relation_row + ) + actual.append(relation_row_serialized) + relation_row = struct_type_metadata.next_relation() + for i in range(0, len(expected_rels)): + self.assertEqual(actual[i], expected_rels[i]) + + def test_serialize_struct_map_array_type_metadata(self) -> None: + column = ColumnMetadata('col1', None, 'struct>,c2:array>', 0) + column.set_column_key(self.column_key) + + struct_type_metadata = StructTypeMetadata( + name='col1', + parent=column, + type_str='struct>,c2:array>' + ) + nested_map_type_metadata_level1 = MapTypeMetadata( + name='c1', + parent=struct_type_metadata, + type_str='map>', + ) + nested_map_key = ScalarTypeMetadata( + name='_map_key', + parent=nested_map_type_metadata_level1, + type_str='string' + ) + nested_array_type_metadata_level2 = ArrayTypeMetadata( + name='_map_value', + parent=nested_map_type_metadata_level1, + type_str='array' + ) + nested_array_type_metadata_level1 = ArrayTypeMetadata( + name='c2', + parent=struct_type_metadata, + type_str='array', + ) + + 
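+ # Assemble the nested type tree by hand before exercising next_node()/next_relation():
+ # field c1 is a map whose string keys point at array values, and field c2 is an array.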
struct_type_metadata.struct_items = {'c1': nested_map_type_metadata_level1, + 'c2': nested_array_type_metadata_level1} + nested_map_type_metadata_level1.map_key_type = nested_map_key + nested_map_type_metadata_level1.map_value_type = nested_array_type_metadata_level2 + nested_map_type_metadata_level1.sort_order = 0 + nested_array_type_metadata_level1.sort_order = 1 + + nested_map_type_metadata_level1.set_description('description of map') + nested_map_type_metadata_level1.set_badges(['badge1']) + nested_array_type_metadata_level1.set_description('description of array') + nested_array_type_metadata_level1.set_badges(['badge1', 'badge2']) + + expected_nodes = [ + {'kind': 'struct', 'name': 'col1', 'LABEL': 'Type_Metadata', + 'data_type': 'struct>,c2:array>', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1'}, + {'kind': 'map', 'name': 'c1', 'data_type': 'map>', + 'LABEL': 'Type_Metadata', 'sort_order:UNQUOTED': 0, + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1'}, + {'description': 'description of map', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1/_description', + 'LABEL': 'Description', 'description_source': 'description'}, + {'KEY': 'badge1', 'LABEL': 'Badge', 'category': 'type_metadata'}, + {'kind': 'scalar', 'name': '_map_key', + 'data_type': 'string', 'LABEL': 'Type_Metadata', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1/_map_key'}, + {'kind': 'array', 'name': '_map_value', + 'data_type': 'array', 'LABEL': 'Type_Metadata', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1/_map_value'}, + {'kind': 'array', 'name': 'c2', 'data_type': 'array', + 'LABEL': 'Type_Metadata', 'sort_order:UNQUOTED': 1, + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c2'}, + {'description': 'description of array', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c2/_description', + 'LABEL': 'Description', 'description_source': 'description'}, + {'KEY': 'badge1', 'LABEL': 'Badge', 'category': 'type_metadata'}, + {'KEY': 'badge2', 'LABEL': 'Badge', 'category': 'type_metadata'} + ] + expected_rels = [ + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Column', + 'TYPE': 'TYPE_METADATA', 'REVERSE_TYPE': 'TYPE_METADATA_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', + 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1/_description', + 'START_LABEL': 'Type_Metadata', 'END_LABEL': 'Description', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1', + 'TYPE': 'DESCRIPTION', 'REVERSE_TYPE': 'DESCRIPTION_OF'}, + {'END_KEY': 'badge1', 'START_LABEL': 'Type_Metadata', 'END_LABEL': 'Badge', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1', + 'TYPE': 'HAS_BADGE', 'REVERSE_TYPE': 'BADGE_FOR'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1/_map_key', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', + 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1/_map_value', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c1', + 'END_LABEL': 
'Type_Metadata', 'START_LABEL': 'Type_Metadata', + 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c2', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', + 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c2/_description', + 'START_LABEL': 'Type_Metadata', 'END_LABEL': 'Description', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c2', + 'TYPE': 'DESCRIPTION', 'REVERSE_TYPE': 'DESCRIPTION_OF'}, + {'END_KEY': 'badge1', 'START_LABEL': 'Type_Metadata', 'END_LABEL': 'Badge', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c2', + 'TYPE': 'HAS_BADGE', 'REVERSE_TYPE': 'BADGE_FOR'}, + {'END_KEY': 'badge2', 'START_LABEL': 'Type_Metadata', 'END_LABEL': 'Badge', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/c2', + 'TYPE': 'HAS_BADGE', 'REVERSE_TYPE': 'BADGE_FOR'} + ] + + node_row = struct_type_metadata.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = struct_type_metadata.next_node() + for i in range(0, len(expected_nodes)): + self.assertEqual(actual[i], expected_nodes[i]) + + relation_row = struct_type_metadata.next_relation() + actual = [] + while relation_row: + relation_row_serialized = neo4_serializer.serialize_relationship( + relation_row + ) + actual.append(relation_row_serialized) + relation_row = struct_type_metadata.next_relation() + for i in range(0, len(expected_rels)): + self.assertEqual(actual[i], expected_rels[i]) + + def test_set_unsupported_descriptions_and_badges(self) -> None: + column = ColumnMetadata('col1', None, 'array>', 0) + column.set_column_key(self.column_key) + + array_type_metadata = ArrayTypeMetadata( + name='col1', + parent=column, + type_str='array>' + ) + nested_array_type_metadata_level1 = ArrayTypeMetadata( + name='_inner_', + parent=array_type_metadata, + type_str='array' + ) + nested_scalar_type_metadata_level2 = ScalarTypeMetadata( + name='_inner_', + parent=nested_array_type_metadata_level1, + type_str='string' + ) + + array_type_metadata.array_inner_type = nested_array_type_metadata_level1 + nested_array_type_metadata_level1.array_inner_type = nested_scalar_type_metadata_level2 + + # Descriptions and badges are set, but they do not appear in the expected nodes and relations + # since they are unsupported for those with parents of ColumnMetadata or ArrayTypeMetadata types + array_type_metadata.set_description('description 1') + array_type_metadata.set_badges(['badge1']) + nested_array_type_metadata_level1.set_description('description 2') + nested_array_type_metadata_level1.set_badges(['badge1']) + + expected_nodes = [ + {'kind': 'array', 'name': 'col1', 'LABEL': 'Type_Metadata', 'data_type': 'array>', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1'}, + {'kind': 'array', 'name': '_inner_', 'LABEL': 'Type_Metadata', 'data_type': 'array', + 'KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_'} + ] + expected_rels = [ + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', + 'START_KEY': 'hive://gold.test_schema1/test_table1/col1', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Column', + 'TYPE': 'TYPE_METADATA', 'REVERSE_TYPE': 'TYPE_METADATA_OF'}, + {'END_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1/_inner_', + 
'START_KEY': 'hive://gold.test_schema1/test_table1/col1/type/col1', + 'END_LABEL': 'Type_Metadata', 'START_LABEL': 'Type_Metadata', + 'TYPE': 'SUBTYPE', 'REVERSE_TYPE': 'SUBTYPE_OF'} + ] + + node_row = array_type_metadata.next_node() + actual = [] + while node_row: + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) + node_row = array_type_metadata.next_node() + for i in range(0, len(expected_nodes)): + self.assertEqual(actual[i], expected_nodes[i]) + + relation_row = array_type_metadata.next_relation() + actual = [] + while relation_row: + relation_row_serialized = neo4_serializer.serialize_relationship( + relation_row + ) + actual.append(relation_row_serialized) + relation_row = array_type_metadata.next_relation() + for i in range(0, len(expected_rels)): + self.assertEqual(actual[i], expected_rels[i]) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/models/test_user.py b/databuilder/tests/unit/models/test_user.py new file mode 100644 index 0000000000..286f5de552 --- /dev/null +++ b/databuilder/tests/unit/models/test_user.py @@ -0,0 +1,224 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from unittest.mock import ANY + +from databuilder.models.graph_serializable import ( + RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, RELATION_START_LABEL, + RELATION_TYPE, +) +from databuilder.models.user import User +from databuilder.serializers import ( + mysql_serializer, neo4_serializer, neptune_serializer, +) +from databuilder.serializers.neptune_serializer import ( + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_HEADER_ID, NEPTUNE_HEADER_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_RELATIONSHIP_HEADER_FROM, + NEPTUNE_RELATIONSHIP_HEADER_TO, +) + + +class TestUser(unittest.TestCase): + + def setUp(self) -> None: + super(TestUser, self).setUp() + self.user = User(first_name='test_first', + last_name='test_last', + full_name='test_first test_last', + email='test@email.com', + github_username='github_test', + team_name='test_team', + employee_type='FTE', + manager_email='test_manager@email.com', + slack_id='slack', + is_active=True, + profile_url='https://profile', + updated_at=1, + role_name='swe') + + def test_get_user_model_key(self) -> None: + user_email = User.get_user_model_key(email=self.user.email) + self.assertEqual(user_email, 'test@email.com') + + def test_create_nodes(self) -> None: + expected_nodes = [{ + 'LABEL': 'User', + 'KEY': 'test@email.com', + 'email': 'test@email.com', + 'is_active:UNQUOTED': True, + 'profile_url': 'https://profile', + 'first_name': 'test_first', + 'last_name': 'test_last', + 'full_name': 'test_first test_last', + 'github_username': 'github_test', + 'team_name': 'test_team', + 'employee_type': 'FTE', + 'slack_id': 'slack', + 'role_name': 'swe', + 'updated_at:UNQUOTED': 1 + }] + + actual = [] + node = self.user.create_next_node() + while node: + serialized_node = neo4_serializer.serialize_node(node) + actual.append(serialized_node) + node = self.user.create_next_node() + + self.assertEqual(actual, expected_nodes) + + def test_create_node_additional_attr(self) -> None: + test_user = User(first_name='test_first', + last_name='test_last', + full_name='test_first test_last', + email='test@email.com', + github_username='github_test', + 
team_name='test_team', + employee_type='FTE', + manager_email='test_manager@email.com', + slack_id='slack', + is_active=True, + profile_url='https://profile', + updated_at=1, + role_name='swe', + enable_notify=True) + node = test_user.create_next_node() + serialized_node = neo4_serializer.serialize_node(node) + self.assertEqual(serialized_node['email'], 'test@email.com') + self.assertEqual(serialized_node['role_name'], 'swe') + self.assertTrue(serialized_node['enable_notify:UNQUOTED']) + + def test_create_node_additional_attr_neptune(self) -> None: + test_user = User(first_name='test_first', + last_name='test_last', + name='test_first test_last', + email='test@email.com', + github_username='github_test', + team_name='test_team', + employee_type='FTE', + manager_email='test_manager@email.com', + slack_id='slack', + is_active=True, + profile_url='https://profile', + updated_at=1, + role_name='swe', + enable_notify=True) + node = test_user.create_next_node() + serialized_node = neptune_serializer.convert_node(node) + self.assertEqual(serialized_node['email:String(single)'], 'test@email.com') + self.assertEqual(serialized_node['role_name:String(single)'], 'swe') + self.assertTrue(serialized_node['enable_notify:Bool(single)']) + + def test_create_record_additional_attr_mysql(self) -> None: + test_user = User(first_name='test_first', + last_name='test_last', + name='test_first test_last', + email='test@email.com', + github_username='github_test', + team_name='test_team', + employee_type='FTE', + manager_email='test_manager@email.com', + slack_id='slack', + is_active=True, + profile_url='https://profile', + updated_at=1, + role_name='swe', + enable_notify=True) + record = test_user.create_next_record() + serialized_record = mysql_serializer.serialize_record(record) + self.assertEqual(serialized_record['email'], 'test@email.com') + self.assertEqual(serialized_record['role_name'], 'swe') + + def test_create_relation(self) -> None: + actual = [] + relation = self.user.create_next_relation() + while relation: + serialized_relation = neo4_serializer.serialize_relationship(relation) + actual.append(serialized_relation) + relation = self.user.create_next_relation() + + start_key = 'test@email.com' + end_key = 'test_manager@email.com' + + expected_relations = [{ + RELATION_START_KEY: start_key, + RELATION_START_LABEL: User.USER_NODE_LABEL, + RELATION_END_KEY: end_key, + RELATION_END_LABEL: User.USER_NODE_LABEL, + RELATION_TYPE: User.USER_MANAGER_RELATION_TYPE, + RELATION_REVERSE_TYPE: User.MANAGER_USER_RELATION_TYPE + }] + + self.assertTrue(expected_relations, actual) + + def test_create_relation_neptune(self) -> None: + actual = [] + relation = self.user.create_next_relation() + while relation: + serialized = neptune_serializer.convert_relationship(relation) + actual.append(serialized) + relation = self.user.create_next_relation() + + start_key = 'User:{email}'.format(email='test@email.com') + end_key = 'User:{email}'.format(email='test_manager@email.com') + + expected = [ + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id=start_key, + to_vertex_id=end_key, + label=User.USER_MANAGER_RELATION_TYPE + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id=start_key, + to_vertex_id=end_key, + label=User.USER_MANAGER_RELATION_TYPE + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: start_key, + NEPTUNE_RELATIONSHIP_HEADER_TO: end_key, + NEPTUNE_HEADER_LABEL: User.USER_MANAGER_RELATION_TYPE, + 
NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id=end_key, + to_vertex_id=start_key, + label=User.MANAGER_USER_RELATION_TYPE + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id=end_key, + to_vertex_id=start_key, + label=User.MANAGER_USER_RELATION_TYPE + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: end_key, + NEPTUNE_RELATIONSHIP_HEADER_TO: start_key, + NEPTUNE_HEADER_LABEL: User.MANAGER_USER_RELATION_TYPE, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ] + ] + + self.assertListEqual(actual, expected) + + def test_not_including_empty_attribute(self) -> None: + test_user = User(email='test@email.com', + foo='bar') + + self.assertDictEqual(neo4_serializer.serialize_node(test_user.create_next_node()), + {'KEY': 'test@email.com', 'LABEL': 'User', 'email': 'test@email.com', + 'is_active:UNQUOTED': True, 'profile_url': '', 'first_name': '', 'last_name': '', + 'full_name': '', 'github_username': '', 'team_name': '', 'employee_type': '', + 'slack_id': '', 'role_name': '', 'updated_at:UNQUOTED': 0, 'foo': 'bar'}) + + test_user2 = User(email='test@email.com', + foo='bar', + is_active=False, + do_not_update_empty_attribute=True) + + self.assertDictEqual(neo4_serializer.serialize_node(test_user2.create_next_node()), + {'KEY': 'test@email.com', 'LABEL': 'User', 'email': 'test@email.com', 'foo': 'bar'}) diff --git a/databuilder/tests/unit/models/test_user_elasticsearch_document.py b/databuilder/tests/unit/models/test_user_elasticsearch_document.py new file mode 100644 index 0000000000..7ac379a178 --- /dev/null +++ b/databuilder/tests/unit/models/test_user_elasticsearch_document.py @@ -0,0 +1,53 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import json +import unittest + +from databuilder.models.user_elasticsearch_document import UserESDocument + + +class TestUserElasticsearchDocument(unittest.TestCase): + + def test_to_json(self) -> None: + """ + Test string generated from to_json method + """ + test_obj = UserESDocument(email='test@email.com', + first_name='test_firstname', + last_name='test_lastname', + full_name='full_name', + github_username='github_user', + team_name='team', + employee_type='fte', + manager_email='test_manager', + slack_id='test_slack', + role_name='role_name', + is_active=True, + total_read=2, + total_own=3, + total_follow=1) + + expected_document_dict = {"first_name": "test_firstname", + "last_name": "test_lastname", + "full_name": "full_name", + "team_name": "team", + "total_follow": 1, + "total_read": 2, + "is_active": True, + "total_own": 3, + "slack_id": 'test_slack', + "role_name": 'role_name', + "manager_email": "test_manager", + 'github_username': "github_user", + "employee_type": 'fte', + "email": "test@email.com", + } + + result = test_obj.to_json() + results = result.split("\n") + + # verify two new line characters in result + self.assertEqual(len(results), 2, "Result from to_json() function doesn't have a newline!") + + self.assertDictEqual(json.loads(results[0]), expected_document_dict) diff --git a/databuilder/tests/unit/models/test_watermark.py b/databuilder/tests/unit/models/test_watermark.py new file mode 100644 index 0000000000..17ae887a7d --- /dev/null +++ b/databuilder/tests/unit/models/test_watermark.py @@ -0,0 +1,200 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from unittest.mock import ANY + +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_serializable import ( + NODE_KEY, NODE_LABEL, RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY, + RELATION_START_LABEL, RELATION_TYPE, +) +from databuilder.models.watermark import Watermark +from databuilder.serializers import ( + mysql_serializer, neo4_serializer, neptune_serializer, +) +from databuilder.serializers.neptune_serializer import ( + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_CREATION_TYPE_JOB, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_HEADER_ID, NEPTUNE_HEADER_LABEL, + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT, NEPTUNE_RELATIONSHIP_HEADER_FROM, + NEPTUNE_RELATIONSHIP_HEADER_TO, +) + +CREATE_TIME = '2017-09-18T00:00:00' +DATABASE = 'DYNAMO' +SCHEMA = 'BASE' +TABLE = 'TEST' +NESTED_PART = 'ds=2017-09-18/feature_id=9' +CLUSTER = 'DEFAULT' +PART_TYPE = 'LOW_WATERMARK' + + +class TestWatermark(unittest.TestCase): + + def setUp(self) -> None: + super(TestWatermark, self).setUp() + self.watermark = Watermark( + create_time='2017-09-18T00:00:00', + database=DATABASE, + schema=SCHEMA, + table_name=TABLE, + cluster=CLUSTER, + part_type=PART_TYPE, + part_name=NESTED_PART + ) + self.start_key = f'{DATABASE}://{CLUSTER}.{SCHEMA}/{TABLE}/{PART_TYPE}/' + self.end_key = f'{DATABASE}://{CLUSTER}.{SCHEMA}/{TABLE}' + self.expected_node_result = GraphNode( + key=self.start_key, + label='Watermark', + attributes={ + 'partition_key': 'ds', + 'partition_value': '2017-09-18/feature_id=9', + 'create_time': '2017-09-18T00:00:00' + } + ) + + 
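+ # The neo4_serializer is expected to flatten this GraphNode into a single dict of KEY, LABEL and its attributes.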
self.expected_serialized_node_results = [{ + NODE_KEY: self.start_key, + NODE_LABEL: 'Watermark', + 'partition_key': 'ds', + 'partition_value': '2017-09-18/feature_id=9', + 'create_time': '2017-09-18T00:00:00' + }] + + self.expected_relation_result = GraphRelationship( + start_label='Watermark', + end_label='Table', + start_key=self.start_key, + end_key=self.end_key, + type='BELONG_TO_TABLE', + reverse_type='WATERMARK', + attributes={} + ) + + self.expected_serialized_relation_results = [{ + RELATION_START_KEY: self.start_key, + RELATION_START_LABEL: 'Watermark', + RELATION_END_KEY: self.end_key, + RELATION_END_LABEL: 'Table', + RELATION_TYPE: 'BELONG_TO_TABLE', + RELATION_REVERSE_TYPE: 'WATERMARK' + }] + + def test_get_watermark_model_key(self) -> None: + watermark = self.watermark.get_watermark_model_key() + self.assertEqual(watermark, f'{DATABASE}://{CLUSTER}.{SCHEMA}/{TABLE}/{PART_TYPE}/') + + def test_get_metadata_model_key(self) -> None: + metadata = self.watermark.get_metadata_model_key() + self.assertEqual(metadata, f'{DATABASE}://{CLUSTER}.{SCHEMA}/{TABLE}') + + def test_create_nodes(self) -> None: + actual = [] + node = self.watermark.create_next_node() + while node: + serialized_node = neo4_serializer.serialize_node(node) + actual.append(serialized_node) + node = self.watermark.create_next_node() + + self.assertEqual(actual, self.expected_serialized_node_results) + + def test_create_nodes_neptune(self) -> None: + expected_serialized_node_results = [{ + NEPTUNE_HEADER_ID: 'Watermark:' + self.start_key, + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: self.start_key, + NEPTUNE_HEADER_LABEL: 'Watermark', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_NODE_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB, + 'partition_key:String(single)': 'ds', + 'partition_value:String(single)': '2017-09-18/feature_id=9', + 'create_time:String(single)': '2017-09-18T00:00:00' + }] + + actual = [] + node = self.watermark.create_next_node() + while node: + serialized_node = neptune_serializer.convert_node(node) + actual.append(serialized_node) + node = self.watermark.create_next_node() + + self.assertEqual(expected_serialized_node_results, actual) + + def test_create_relation(self) -> None: + actual = [] + relation = self.watermark.create_next_relation() + while relation: + serialized_relation = neo4_serializer.serialize_relationship(relation) + actual.append(serialized_relation) + relation = self.watermark.create_next_relation() + + self.assertEqual(actual, self.expected_serialized_relation_results) + + def test_create_relation_neptune(self) -> None: + actual = [] + relation = self.watermark.create_next_relation() + while relation: + serialized_relation = neptune_serializer.convert_relationship(relation) + actual.append(serialized_relation) + relation = self.watermark.create_next_relation() + + expected = [ + [ + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="Watermark:" + self.start_key, + to_vertex_id="Table:" + self.end_key, + label='BELONG_TO_TABLE' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="Watermark:" + self.start_key, + to_vertex_id="Table:" + self.end_key, + label='BELONG_TO_TABLE' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: "Watermark:" + self.start_key, + NEPTUNE_RELATIONSHIP_HEADER_TO: "Table:" + self.end_key, + NEPTUNE_HEADER_LABEL: 'BELONG_TO_TABLE', + 
NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + }, + { + NEPTUNE_HEADER_ID: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="Table:" + self.end_key, + to_vertex_id="Watermark:" + self.start_key, + label='WATERMARK' + ), + METADATA_KEY_PROPERTY_NAME_BULK_LOADER_FORMAT: "{label}:{from_vertex_id}_{to_vertex_id}".format( + from_vertex_id="Table:" + self.end_key, + to_vertex_id="Watermark:" + self.start_key, + label='WATERMARK' + ), + NEPTUNE_RELATIONSHIP_HEADER_FROM: "Table:" + self.end_key, + NEPTUNE_RELATIONSHIP_HEADER_TO: "Watermark:" + self.start_key, + NEPTUNE_HEADER_LABEL: 'WATERMARK', + NEPTUNE_LAST_EXTRACTED_AT_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: ANY, + NEPTUNE_CREATION_TYPE_RELATIONSHIP_PROPERTY_NAME_BULK_LOADER_FORMAT: NEPTUNE_CREATION_TYPE_JOB + } + ] + ] + + self.assertListEqual(actual, expected) + + def test_create_records(self) -> None: + expected = [{ + 'rk': self.start_key, + 'partition_key': 'ds', + 'partition_value': '2017-09-18/feature_id=9', + 'create_time': '2017-09-18T00:00:00', + 'table_rk': self.end_key + }] + + actual = [] + record = self.watermark.create_next_record() + while record: + serialized_record = mysql_serializer.serialize_record(record) + actual.append(serialized_record) + record = self.watermark.create_next_record() + + self.assertEqual(actual, expected) diff --git a/databuilder/tests/unit/models/usage/test_usage.py b/databuilder/tests/unit/models/usage/test_usage.py new file mode 100644 index 0000000000..26cb74b913 --- /dev/null +++ b/databuilder/tests/unit/models/usage/test_usage.py @@ -0,0 +1,101 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from databuilder.models.usage.usage import Usage +from databuilder.serializers import mysql_serializer, neo4_serializer + + +class TestUsage(unittest.TestCase): + + def setUp(self) -> None: + self.usage = Usage( + start_label='Table', + start_key='the_key', + user_email='foo@bar.biz', + read_count=42, + ) + + self.expected_nodes = [ + { + 'KEY': 'foo@bar.biz', + 'LABEL': 'User', + 'email': 'foo@bar.biz', + } + ] + + self.expected_relations = [ + { + 'START_LABEL': 'Table', + 'END_LABEL': 'User', + 'START_KEY': 'the_key', + 'END_KEY': 'foo@bar.biz', + 'TYPE': 'READ_BY', + 'REVERSE_TYPE': 'READ', + 'read_count:UNQUOTED': 42, + }, + ] + + self.expected_records = [ + { + 'rk': 'foo@bar.biz', + 'email': 'foo@bar.biz', + }, + { + 'table_rk': 'the_key', + 'user_rk': 'foo@bar.biz', + 'read_count': 42, + } + ] + + def test_usage_not_supported(self) -> None: + with self.assertRaises(Exception) as e: + Usage( + start_label='User', # users can't have usage + start_key='user@user.us', + user_email='another_user@user.us', + ) + self.assertEqual(e.exception.args, ('usage for User is not supported',)) + + def test_usage_nodes(self) -> None: + node = self.usage.next_node() + actual = [] + while node: + node_serialized = neo4_serializer.serialize_node(node) + actual.append(node_serialized) + node = self.usage.next_node() + + self.assertEqual(actual, self.expected_nodes) + + def test_usage_relations(self) -> None: + actual = [] + relation = self.usage.create_next_relation() + while relation: + serialized_relation = neo4_serializer.serialize_relationship(relation) + actual.append(serialized_relation) + relation = self.usage.create_next_relation() + + self.assertEqual(actual, self.expected_relations) + + def 
test_usage_record(self) -> None: + actual = [] + record = self.usage.create_next_record() + while record: + serialized_record = mysql_serializer.serialize_record(record) + actual.append(serialized_record) + record = self.usage.create_next_record() + + self.assertEqual(actual, self.expected_records) + + def test_usage_not_table_serializable(self) -> None: + feature_usage = Usage( + start_label='Feature', + start_key='feature://a/b/c', + user_email='user@user.us', + ) + with self.assertRaises(Exception) as e: + record = feature_usage.create_next_record() + while record: + record = feature_usage.create_next_record() + self.assertEqual(e.exception.args, ('Feature usage is not table serializable',)) diff --git a/databuilder/tests/unit/publisher/__init__.py b/databuilder/tests/unit/publisher/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/publisher/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/publisher/test_atlas_csv_publisher.py b/databuilder/tests/unit/publisher/test_atlas_csv_publisher.py new file mode 100644 index 0000000000..728e397466 --- /dev/null +++ b/databuilder/tests/unit/publisher/test_atlas_csv_publisher.py @@ -0,0 +1,47 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 +import logging +import os +import unittest +from unittest.mock import MagicMock + +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.publisher.atlas_csv_publisher import AtlasCSVPublisher + + +class TestAtlasCsvPublisher(unittest.TestCase): + + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self._resource_path = os.path.join(os.path.dirname(__file__), '../resources/atlas_csv_publisher') + self.mock_atlas_client = MagicMock() + + config_dict = { + 'publisher.atlas_csv_publisher.entity_files_directory': f'{self._resource_path}/entities', + 'publisher.atlas_csv_publisher.relationship_files_directory': f'{self._resource_path}/relationships', + 'publisher.atlas_csv_publisher.batch_size': 1, + 'publisher.atlas_csv_publisher.atlas_client': self.mock_atlas_client, + } + + self._conf = ConfigFactory.from_dict(config_dict) + + def test_publisher(self) -> None: + publisher = AtlasCSVPublisher() + publisher.init( + conf=Scoped.get_scoped_conf( + conf=self._conf, + scope=publisher.get_scope(), + ), + ) + publisher.publish() + + # 4 entities to create + self.assertEqual(self.mock_atlas_client.entity.create_entities.call_count, 4) + + # 1 entity to update + self.assertEqual(self.mock_atlas_client.entity.update_entity.call_count, 1) + + # 2 relationships to create + self.assertEqual(self.mock_atlas_client.relationship.create_relationship.call_count, 2) diff --git a/databuilder/tests/unit/publisher/test_elasticsearch_publisher.py b/databuilder/tests/unit/publisher/test_elasticsearch_publisher.py new file mode 100644 index 0000000000..ff3c79f2a3 --- /dev/null +++ b/databuilder/tests/unit/publisher/test_elasticsearch_publisher.py @@ -0,0 +1,121 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import json +import unittest + +from mock import ( + MagicMock, mock_open, patch, +) +from pyhocon import ConfigFactory + +from databuilder import Scoped +from databuilder.publisher.elasticsearch_publisher import ElasticsearchPublisher + + +class TestElasticsearchPublisher(unittest.TestCase): + + def setUp(self) -> None: + self.test_file_path = 'test_publisher_file.json' + self.test_file_mode = 'r' + + self.mock_es_client = MagicMock() + self.test_es_new_index = 'test_new_index' + self.test_es_alias = 'test_index_alias' + self.test_doc_type = 'test_doc_type' + + config_dict = {'publisher.elasticsearch.file_path': self.test_file_path, + 'publisher.elasticsearch.mode': self.test_file_mode, + 'publisher.elasticsearch.client': self.mock_es_client, + 'publisher.elasticsearch.new_index': self.test_es_new_index, + 'publisher.elasticsearch.alias': self.test_es_alias, + 'publisher.elasticsearch.doc_type': self.test_doc_type} + + self.conf = ConfigFactory.from_dict(config_dict) + + def test_publish_with_no_data(self) -> None: + """ + Test Publish functionality with no data + """ + with patch('builtins.open', mock_open(read_data='')) as mock_file: + publisher = ElasticsearchPublisher() + publisher.init(conf=Scoped.get_scoped_conf(conf=self.conf, + scope=publisher.get_scope())) + + # assert mock was called with test_file_path and test_file_mode + mock_file.assert_called_with(self.test_file_path, self.test_file_mode) + + publisher.publish() + # no calls should be made through elasticseach_client when there is no data + self.assertTrue(self.mock_es_client.call_count == 0) + + def test_publish_with_data_and_no_old_index(self) -> None: + """ + Test Publish functionality with data but no index in place + """ + mock_data = json.dumps({'KEY_DOESNOT_MATTER': 'NO_VALUE', + 'KEY_DOESNOT_MATTER2': 'NO_VALUE2'}) + self.mock_es_client.indices.get_alias.return_value = {} + + with patch('builtins.open', mock_open(read_data=mock_data)) as mock_file: + publisher = ElasticsearchPublisher() + publisher.init(conf=Scoped.get_scoped_conf(conf=self.conf, + scope=publisher.get_scope())) + + # assert mock was called with test_file_path and test_file_mode + mock_file.assert_called_once_with(self.test_file_path, self.test_file_mode) + + publisher.publish() + # ensure indices create endpoint was called + default_mapping = ElasticsearchPublisher.DEFAULT_ELASTICSEARCH_INDEX_MAPPING + self.mock_es_client.indices.create.assert_called_once_with(index=self.test_es_new_index, + body=default_mapping) + + # bulk endpoint called once + self.mock_es_client.bulk.assert_called_once_with( + [{'index': {'_index': self.test_es_new_index}}, + {'KEY_DOESNOT_MATTER': 'NO_VALUE', + 'KEY_DOESNOT_MATTER2': 'NO_VALUE2', + 'resource_type': 'test_doc_type'}] + ) + + # update alias endpoint called once + self.mock_es_client.indices.update_aliases.assert_called_once_with( + {'actions': [{"add": {"index": self.test_es_new_index, "alias": self.test_es_alias}}]} + ) + + def test_publish_with_data_and_old_index(self) -> None: + """ + Test Publish functionality with data and with old_index in place + """ + mock_data = json.dumps({'KEY_DOESNOT_MATTER': 'NO_VALUE', + 'KEY_DOESNOT_MATTER2': 'NO_VALUE2'}) + self.mock_es_client.indices.get_alias.return_value = {'test_old_index': 'DOES_NOT_MATTER'} + + with patch('builtins.open', mock_open(read_data=mock_data)) as mock_file: + publisher = ElasticsearchPublisher() + publisher.init(conf=Scoped.get_scoped_conf(conf=self.conf, + scope=publisher.get_scope())) + + # assert mock 
was called with test_file_path and test_file_mode + mock_file.assert_called_once_with(self.test_file_path, self.test_file_mode) + + publisher.publish() + # ensure indices create endpoint was called + default_mapping = ElasticsearchPublisher.DEFAULT_ELASTICSEARCH_INDEX_MAPPING + self.mock_es_client.indices.create.assert_called_once_with(index=self.test_es_new_index, + body=default_mapping) + + # bulk endpoint called once + self.mock_es_client.bulk.assert_called_once_with( + [{'index': {'_index': self.test_es_new_index}}, + {'KEY_DOESNOT_MATTER': 'NO_VALUE', + 'KEY_DOESNOT_MATTER2': 'NO_VALUE2', + 'resource_type': 'test_doc_type'}] + ) + + # update alias endpoint called once + self.mock_es_client.indices.update_aliases.assert_called_once_with( + {'actions': [{"add": {"index": self.test_es_new_index, "alias": self.test_es_alias}}, + {"remove_index": {"index": 'test_old_index'}}]} + ) diff --git a/databuilder/tests/unit/publisher/test_mysql_csv_publisher.py b/databuilder/tests/unit/publisher/test_mysql_csv_publisher.py new file mode 100644 index 0000000000..9df249fadd --- /dev/null +++ b/databuilder/tests/unit/publisher/test_mysql_csv_publisher.py @@ -0,0 +1,63 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import os +import unittest +from typing import Any +from unittest.mock import MagicMock, patch + +from freezegun import freeze_time +from pyhocon import ConfigFactory + +from databuilder.publisher import mysql_csv_publisher +from databuilder.publisher.mysql_csv_publisher import MySQLCSVPublisher +from tests.unit.models.test_table_serializable import Base + +here = os.path.dirname(__file__) + + +class TestMySQLPublish(unittest.TestCase): + + def setUp(self) -> None: + resource_path = os.path.join(here, '../resources/mysql_csv_publisher') + self.conf = ConfigFactory.from_dict( + { + MySQLCSVPublisher.CONN_STRING: 'test_connection', + MySQLCSVPublisher.RECORD_FILES_DIR: f'{resource_path}/records', + MySQLCSVPublisher.JOB_PUBLISH_TAG: 'test' + } + ) + + @freeze_time("2021-01-01 01:01:00") + @patch.object(mysql_csv_publisher, 'sessionmaker') + @patch.object(mysql_csv_publisher, 'create_engine') + def test_publisher_old(self, mock_create_engine: Any, mock_session_maker: Any) -> None: + mock_engine = MagicMock() + mock_create_engine.return_value = mock_engine + + mock_session_factory = MagicMock() + mock_session_maker.return_value = mock_session_factory + + mock_session = MagicMock() + mock_session_factory.return_value = mock_session + + mock_merge = MagicMock() + mock_session.merge = mock_merge + + mock_commit = MagicMock() + mock_session.commit = mock_commit + + mysql_csv_publisher.Base = Base + + publisher = MySQLCSVPublisher() + publisher.init(self.conf) + publisher.publish() + + # 5 records + self.assertEqual(5, mock_merge.call_count) + # 3 record files + self.assertEqual(3, mock_commit.call_count) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/publisher/test_neo4j_csv_publisher.py b/databuilder/tests/unit/publisher/test_neo4j_csv_publisher.py new file mode 100644 index 0000000000..5c6dd18f6f --- /dev/null +++ b/databuilder/tests/unit/publisher/test_neo4j_csv_publisher.py @@ -0,0 +1,94 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +import unittest +import uuid + +from mock import MagicMock, patch +from neo4j import GraphDatabase +from pyhocon import ConfigFactory + +from databuilder.publisher import neo4j_csv_publisher +from databuilder.publisher.neo4j_csv_publisher import Neo4jCsvPublisher + +here = os.path.dirname(__file__) + + +class TestPublish(unittest.TestCase): + + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self._resource_path = os.path.join(here, '../resources/csv_publisher') + + def test_publisher(self) -> None: + with patch.object(GraphDatabase, 'driver') as mock_driver: + mock_session = MagicMock() + mock_driver.return_value.session.return_value = mock_session + + mock_transaction = MagicMock() + mock_session.begin_transaction.return_value = mock_transaction + + mock_run = MagicMock() + mock_transaction.run = mock_run + mock_commit = MagicMock() + mock_transaction.commit = mock_commit + + publisher = Neo4jCsvPublisher() + + conf = ConfigFactory.from_dict( + {neo4j_csv_publisher.NEO4J_END_POINT_KEY: 'bolt://999.999.999.999:7687', + neo4j_csv_publisher.NODE_FILES_DIR: f'{self._resource_path}/nodes', + neo4j_csv_publisher.RELATION_FILES_DIR: f'{self._resource_path}/relations', + neo4j_csv_publisher.NEO4J_USER: 'neo4j_user', + neo4j_csv_publisher.NEO4J_PASSWORD: 'neo4j_password', + neo4j_csv_publisher.JOB_PUBLISH_TAG: str(uuid.uuid4())} + ) + publisher.init(conf) + publisher.publish() + + self.assertEqual(mock_run.call_count, 6) + + # 2 node files, 1 relation file + self.assertEqual(mock_commit.call_count, 1) + + def test_preprocessor(self) -> None: + with patch.object(GraphDatabase, 'driver') as mock_driver: + mock_session = MagicMock() + mock_driver.return_value.session.return_value = mock_session + + mock_transaction = MagicMock() + mock_session.begin_transaction.return_value = mock_transaction + + mock_run = MagicMock() + mock_transaction.run = mock_run + mock_commit = MagicMock() + mock_transaction.commit = mock_commit + + mock_preprocessor = MagicMock() + mock_preprocessor.is_perform_preprocess.return_value = MagicMock(return_value=True) + mock_preprocessor.preprocess_cypher.return_value = ('MATCH (f:Foo) RETURN f', {}) + + publisher = Neo4jCsvPublisher() + + conf = ConfigFactory.from_dict( + {neo4j_csv_publisher.NEO4J_END_POINT_KEY: 'bolt://999.999.999.999:7687', + neo4j_csv_publisher.NODE_FILES_DIR: f'{self._resource_path}/nodes', + neo4j_csv_publisher.RELATION_FILES_DIR: f'{self._resource_path}/relations', + neo4j_csv_publisher.RELATION_PREPROCESSOR: mock_preprocessor, + neo4j_csv_publisher.NEO4J_USER: 'neo4j_user', + neo4j_csv_publisher.NEO4J_PASSWORD: 'neo4j_password', + neo4j_csv_publisher.JOB_PUBLISH_TAG: str(uuid.uuid4())} + ) + publisher.init(conf) + publisher.publish() + + self.assertEqual(mock_run.call_count, 8) + + # 2 node files, 1 relation file + self.assertEqual(mock_commit.call_count, 1) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/publisher/test_neo4j_csv_unwind_publisher.py b/databuilder/tests/unit/publisher/test_neo4j_csv_unwind_publisher.py new file mode 100644 index 0000000000..230319a0c4 --- /dev/null +++ b/databuilder/tests/unit/publisher/test_neo4j_csv_unwind_publisher.py @@ -0,0 +1,74 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +import unittest +import uuid + +from mock import MagicMock, patch +from neo4j import GraphDatabase +from pyhocon import ConfigFactory + +from databuilder.publisher.neo4j_csv_unwind_publisher import Neo4jCsvUnwindPublisher +from databuilder.publisher.publisher_config_constants import Neo4jCsvPublisherConfigs, PublisherConfigs + +here = os.path.dirname(__file__) + + +class TestPublish(unittest.TestCase): + + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + self._resource_path = os.path.join(here, '../resources/csv_publisher') + + def test_publisher_write_exception(self) -> None: + with patch.object(GraphDatabase, 'driver') as mock_driver: + mock_session = MagicMock() + mock_driver.return_value.session.return_value = mock_session + + mock_write_transaction = MagicMock(side_effect=Exception('Could not write')) + mock_session.__enter__.return_value.write_transaction = mock_write_transaction + + publisher = Neo4jCsvUnwindPublisher() + + conf = ConfigFactory.from_dict( + {Neo4jCsvPublisherConfigs.NEO4J_END_POINT_KEY: 'bolt://999.999.999.999:7687/', + PublisherConfigs.NODE_FILES_DIR: f'{self._resource_path}/nodes', + PublisherConfigs.RELATION_FILES_DIR: f'{self._resource_path}/relations', + Neo4jCsvPublisherConfigs.NEO4J_USER: 'neo4j_user', + Neo4jCsvPublisherConfigs.NEO4J_PASSWORD: 'neo4j_password', + PublisherConfigs.JOB_PUBLISH_TAG: str(uuid.uuid4())} + ) + publisher.init(conf) + + with self.assertRaises(Exception): + publisher.publish() + + def test_publisher(self) -> None: + with patch.object(GraphDatabase, 'driver') as mock_driver: + mock_session = MagicMock() + mock_driver.return_value.session.return_value = mock_session + + mock_write_transaction = MagicMock() + mock_session.__enter__.return_value.write_transaction = mock_write_transaction + + publisher = Neo4jCsvUnwindPublisher() + + conf = ConfigFactory.from_dict( + {Neo4jCsvPublisherConfigs.NEO4J_END_POINT_KEY: 'bolt://999.999.999.999:7687/', + PublisherConfigs.NODE_FILES_DIR: f'{self._resource_path}/nodes', + PublisherConfigs.RELATION_FILES_DIR: f'{self._resource_path}/relations', + Neo4jCsvPublisherConfigs.NEO4J_USER: 'neo4j_user', + Neo4jCsvPublisherConfigs.NEO4J_PASSWORD: 'neo4j_password', + PublisherConfigs.JOB_PUBLISH_TAG: str(uuid.uuid4())} + ) + publisher.init(conf) + publisher.publish() + + # Create 2 indices, write 2 node files, write 1 relation file + self.assertEqual(5, mock_write_transaction.call_count) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/publisher/test_neo4j_preprocessor.py b/databuilder/tests/unit/publisher/test_neo4j_preprocessor.py new file mode 100644 index 0000000000..41acbf7705 --- /dev/null +++ b/databuilder/tests/unit/publisher/test_neo4j_preprocessor.py @@ -0,0 +1,80 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import textwrap +import unittest +import uuid + +from databuilder.publisher.neo4j_preprocessor import DeleteRelationPreprocessor, NoopRelationPreprocessor + + +class TestNeo4jPreprocessor(unittest.TestCase): + + def testNoopRelationPreprocessor(self) -> None: + preprocessor = NoopRelationPreprocessor() + + self.assertTrue(not preprocessor.is_perform_preprocess()) + + def testDeleteRelationPreprocessor(self) -> None: # noqa: W293 + preprocessor = DeleteRelationPreprocessor() + + self.assertTrue(preprocessor.is_perform_preprocess()) + + preprocessor.filter(start_label='foo_label', + end_label='bar_label', + start_key='foo_key', + end_key='bar_key', + relation='foo_relation', + reverse_relation='bar_relation') + + self.assertTrue(preprocessor.filter(start_label=str(uuid.uuid4()), + end_label=str(uuid.uuid4()), + start_key=str(uuid.uuid4()), + end_key=str(uuid.uuid4()), + relation=str(uuid.uuid4()), + reverse_relation=str(uuid.uuid4()))) + + actual = preprocessor.preprocess_cypher(start_label='foo_label', + end_label='bar_label', + start_key='foo_key', + end_key='bar_key', + relation='foo_relation', + reverse_relation='bar_relation') + + expected = (textwrap.dedent(""" + MATCH (n1:foo_label {key: $start_key })-[r]-(n2:bar_label {key: $end_key }) + + WITH r LIMIT 2 + DELETE r + RETURN count(*) as count; + """), {'start_key': 'foo_key', 'end_key': 'bar_key'}) + + self.assertEqual(expected, actual) + + def testDeleteRelationPreprocessorFilter(self) -> None: + preprocessor = DeleteRelationPreprocessor(label_tuples=[('foo', 'bar')]) + + self.assertTrue(preprocessor.filter(start_label='foo', + end_label='bar', + start_key=str(uuid.uuid4()), + end_key=str(uuid.uuid4()), + relation=str(uuid.uuid4()), + reverse_relation=str(uuid.uuid4()))) + + self.assertTrue(preprocessor.filter(start_label='bar', + end_label='foo', + start_key=str(uuid.uuid4()), + end_key=str(uuid.uuid4()), + relation=str(uuid.uuid4()), + reverse_relation=str(uuid.uuid4()))) + + self.assertFalse(preprocessor.filter(start_label='foz', + end_label='baz', + start_key=str(uuid.uuid4()), + end_key=str(uuid.uuid4()), + relation=str(uuid.uuid4()), + reverse_relation=str(uuid.uuid4()))) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/publisher/test_publisher.py b/databuilder/tests/unit/publisher/test_publisher.py new file mode 100644 index 0000000000..166914253d --- /dev/null +++ b/databuilder/tests/unit/publisher/test_publisher.py @@ -0,0 +1,47 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from mock import MagicMock +from pyhocon import ConfigTree + +from databuilder.publisher.base_publisher import NoopPublisher, Publisher + + +class TestPublisher(unittest.TestCase): + + def testCallback(self) -> None: + publisher = NoopPublisher() + callback = MagicMock() + publisher.register_call_back(callback) + publisher.publish() + + self.assertTrue(callback.on_success.called) + + def testFailureCallback(self) -> None: + publisher = FailedPublisher() + callback = MagicMock() + publisher.register_call_back(callback) + + try: + publisher.publish() + except Exception: + pass + + self.assertTrue(callback.on_failure.called) + + +class FailedPublisher(Publisher): + def __init__(self) -> None: + super(FailedPublisher, self).__init__() + + def init(self, conf: ConfigTree) -> None: + pass + + def publish_impl(self) -> None: + raise Exception('Bomb') + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/resources/atlas_csv_publisher/entities/000_Actor.csv b/databuilder/tests/unit/resources/atlas_csv_publisher/entities/000_Actor.csv new file mode 100644 index 0000000000..42f6b963c6 --- /dev/null +++ b/databuilder/tests/unit/resources/atlas_csv_publisher/entities/000_Actor.csv @@ -0,0 +1,3 @@ +"name","operation","qualifiedName","typeName","relationships" +"Tom Cruise","CREATE","actor://Tom Cruise","Actor", +"Meg Ryan","CREATE","actor://Meg Ryan","Actor", diff --git a/databuilder/tests/unit/resources/atlas_csv_publisher/entities/000_City.csv b/databuilder/tests/unit/resources/atlas_csv_publisher/entities/000_City.csv new file mode 100644 index 0000000000..15d100ad5f --- /dev/null +++ b/databuilder/tests/unit/resources/atlas_csv_publisher/entities/000_City.csv @@ -0,0 +1,3 @@ +"name","operation","qualifiedName","typeName","relationships" +"San Diego","CREATE","city://San Diego","City","" +"Oakland","UPDATE","city://Oakland","City","" diff --git a/databuilder/tests/unit/resources/atlas_csv_publisher/entities/000_Movie.csv b/databuilder/tests/unit/resources/atlas_csv_publisher/entities/000_Movie.csv new file mode 100644 index 0000000000..dc34661068 --- /dev/null +++ b/databuilder/tests/unit/resources/atlas_csv_publisher/entities/000_Movie.csv @@ -0,0 +1,2 @@ +"name","operation","qualifiedName","typeName","relationships" +"Top Gun","CREATE","movie://Top Gun","Movie","actors#ACTOR#actor://Tom Cruise|actors#ACTOR#actor://Meg Ryan" diff --git a/databuilder/tests/unit/resources/atlas_csv_publisher/relationships/001_Movie_City.csv b/databuilder/tests/unit/resources/atlas_csv_publisher/relationships/001_Movie_City.csv new file mode 100644 index 0000000000..f8ad441801 --- /dev/null +++ b/databuilder/tests/unit/resources/atlas_csv_publisher/relationships/001_Movie_City.csv @@ -0,0 +1,3 @@ +"relationshipType","entityType1","entityQualifiedName1","entityType2","entityQualifiedName2" +"FILMED_AT","Movie","movie://Top Gun","City","city://San Diego" +"FILMED_AT","Movie","movie://Top Gun","City","city://Oakland" diff --git a/databuilder/tests/unit/resources/csv_publisher/nodes/test_column.csv b/databuilder/tests/unit/resources/csv_publisher/nodes/test_column.csv new file mode 100644 index 0000000000..9e0f4f7f1b --- /dev/null +++ b/databuilder/tests/unit/resources/csv_publisher/nodes/test_column.csv @@ -0,0 +1,3 @@ +"KEY","name","order_pos:UNQUOTED","type","LABEL" +"presto://gold.test_schema1/test_table1/test_id1","test_id1",1,"bigint","Column" +"presto://gold.test_schema1/test_table1/test_id2","test_id2",2,"bigint","Column" 
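The node CSV fixture above (and the table fixture that follows) use the convention these publisher tests rely on: a KEY and LABEL column plus arbitrary property columns, where a ":UNQUOTED" suffix on a header marks values (numbers, booleans) that should be loaded as raw literals rather than quoted strings. The sketch below is only an illustration of reading that format under those assumptions; the helper name read_node_rows and the coercion rules are invented for the example and are not the publisher's actual implementation.

# Illustrative only: a minimal reader for node CSV fixtures in the format shown above.
# This is NOT Neo4jCsvPublisher; the handling of the ":UNQUOTED" header suffix is inferred
# from the test data (it marks values that should stay unquoted, e.g. numbers and booleans).
import csv
import io

SAMPLE = '''"KEY","name","order_pos:UNQUOTED","type","LABEL"
"presto://gold.test_schema1/test_table1/test_id1","test_id1",1,"bigint","Column"
'''


def read_node_rows(fileobj):
    """Yield one dict per row, stripping the ':UNQUOTED' marker and coercing those values."""
    for row in csv.DictReader(fileobj):
        parsed = {}
        for column, value in row.items():
            if column.endswith(':UNQUOTED'):
                name = column[:-len(':UNQUOTED')]
                if value.lower() in ('true', 'false'):
                    parsed[name] = value.lower() == 'true'
                else:
                    try:
                        parsed[name] = int(value)
                    except ValueError:
                        try:
                            parsed[name] = float(value)
                        except ValueError:
                            parsed[name] = value
            else:
                parsed[column] = value
        yield parsed


if __name__ == '__main__':
    for node in read_node_rows(io.StringIO(SAMPLE)):
        print(node['LABEL'], node['KEY'], node['order_pos'])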
diff --git a/databuilder/tests/unit/resources/csv_publisher/nodes/test_table.csv b/databuilder/tests/unit/resources/csv_publisher/nodes/test_table.csv new file mode 100644 index 0000000000..6c01542bce --- /dev/null +++ b/databuilder/tests/unit/resources/csv_publisher/nodes/test_table.csv @@ -0,0 +1,3 @@ +"KEY","name","LABEL" +"presto://gold.test_schema1/test_table1","test_table1","Table" +"presto://gold.test_schema1/test_table2","test_table2","Table" diff --git a/databuilder/tests/unit/resources/csv_publisher/relations/test_edge_short.csv b/databuilder/tests/unit/resources/csv_publisher/relations/test_edge_short.csv new file mode 100644 index 0000000000..818571e532 --- /dev/null +++ b/databuilder/tests/unit/resources/csv_publisher/relations/test_edge_short.csv @@ -0,0 +1,3 @@ +"START_LABEL","START_KEY","END_LABEL","END_KEY","TYPE","REVERSE_TYPE" +"Table","presto://gold.test_schema1/test_table1","Column","presto://gold.test_schema1/test_table1/test_id1","COLUMN","BELONG_TO_TABLE" +"Table","presto://gold.test_schema1/test_table1","Column","presto://gold.test_schema1/test_table1/test_id2","COLUMN","BELONG_TO_TABLE" diff --git a/databuilder/tests/unit/resources/extractor/feast/fs/feature_store.yaml b/databuilder/tests/unit/resources/extractor/feast/fs/feature_store.yaml new file mode 100644 index 0000000000..c48a32c6aa --- /dev/null +++ b/databuilder/tests/unit/resources/extractor/feast/fs/feature_store.yaml @@ -0,0 +1,5 @@ +project: fs +registry: data/registry.db +provider: local +online_store: + path: data/online_store.db \ No newline at end of file diff --git a/databuilder/tests/unit/resources/extractor/feast/fs/feature_view.py b/databuilder/tests/unit/resources/extractor/feast/fs/feature_view.py new file mode 100644 index 0000000000..48326ea688 --- /dev/null +++ b/databuilder/tests/unit/resources/extractor/feast/fs/feature_view.py @@ -0,0 +1,73 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import pathlib +import re +from datetime import datetime + +from feast import ( + Entity, Feature, FeatureView, FileSource, KafkaSource, ValueType, +) +from feast.data_format import AvroFormat +from google.protobuf.duration_pb2 import Duration + +# Read data from parquet files. Parquet is convenient for local development mode. For +# production, you can use your favorite DWH, such as BigQuery. See Feast documentation +# for more info. + +root_path = pathlib.Path(__file__).parent.resolve() +driver_hourly_stats = FileSource( + path=f"{root_path}/data/driver_stats.parquet", + event_timestamp_column="event_timestamp", + created_timestamp_column="created", +) + +driver_hourly_stats_kafka_source = KafkaSource( + bootstrap_servers="broker1", + event_timestamp_column="datetime", + created_timestamp_column="datetime", + topic="driver_hourly_stats", + message_format=AvroFormat( + schema_json=re.sub( + "\n[ \t]*\\|", + "", + """'{"type": "record", + |"name": "driver_hourly_stats", + |"fields": [ + | {"name": "conv_rate", "type": "float"}, + | {"name": "acc_rate", "type": "float"}, + | {"name": "avg_daily_trips", "type": "int"}, + | {"name": "datetime", "type": {"type": "long", "logicalType": "timestamp-micros"}}]}'""", + ) + ), +) + +# Define an entity for the driver. You can think of entity as a primary key used to +# fetch features. 
+driver = Entity( + name="driver_id", + value_type=ValueType.INT64, + description="Internal identifier of the driver", +) + +# Our parquet files contain sample data that includes a driver_id column, timestamps and +# three feature columns. Here we define a Feature View that will allow us to serve this +# data to our model online. +driver_hourly_stats_view = FeatureView( + name="driver_hourly_stats", + entities=["driver_id"], + ttl=Duration(seconds=86400 * 1), + features=[ + Feature(name="conv_rate", dtype=ValueType.FLOAT), + Feature(name="acc_rate", dtype=ValueType.FLOAT), + Feature(name="avg_daily_trips", dtype=ValueType.INT64), + ], + online=True, + stream_source=driver_hourly_stats_kafka_source, + batch_source=driver_hourly_stats, + tags={"is_pii": "true"}, +) + +driver_hourly_stats_view.created_timestamp = datetime.strptime( + "2020-01-01 00:00:00", "%Y-%m-%d %H:%M:%S" +) diff --git a/databuilder/tests/unit/resources/extractor/user/bamboohr/testdata.xml b/databuilder/tests/unit/resources/extractor/user/bamboohr/testdata.xml new file mode 100644 index 0000000000..700b965bbd --- /dev/null +++ b/databuilder/tests/unit/resources/extractor/user/bamboohr/testdata.xml @@ -0,0 +1,39 @@ + + +
+ Display name + First name + Last name + Preferred name + Gender + Job title + Work Phone + Mobile Phone + Work Email + Department + Location + Work Ext. + Employee photo + Photo URL + Can Upload Photo +
+ + + Roald Amundsen + Roald + Amundsen + + Male + Antarctic Explorer + + + roald@amundsen.io + 508 Corporate Marketing + Norway + + true + https://upload.wikimedia.org/wikipedia/commons/thumb/6/6f/Amundsen_in_fur_skins.jpg/440px-Amundsen_in_fur_skins.jpg + no + + +
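The Feast fixture feature_view.py above defines an entity and a feature view against the local, file-based feature store described by feature_store.yaml. As a rough sketch of how that repository could be exercised, the snippet below registers the fixture objects with a FeatureStore and reads the view back; it assumes a Feast release compatible with the fixture's imports, and the module import, paths, and describe_view helper are illustrative rather than part of the test suite.

# A minimal sketch, assuming a Feast version compatible with the fixture above
# (feast.FeatureStore with apply()/get_feature_view()); not part of the tests.
from feast import FeatureStore

# Assumes the fixture directory is on sys.path so its module can be imported.
from feature_view import driver, driver_hourly_stats_view


def describe_view(repo_path: str) -> None:
    # Register the fixture objects and print basic feature view metadata.
    store = FeatureStore(repo_path=repo_path)          # reads feature_store.yaml
    store.apply([driver, driver_hourly_stats_view])    # writes to data/registry.db

    view = store.get_feature_view("driver_hourly_stats")
    print(view.name, [feature.name for feature in view.features], view.tags)


if __name__ == "__main__":
    # Hypothetical path to the fixture directory shown in this diff.
    describe_view("databuilder/tests/unit/resources/extractor/feast/fs")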
diff --git a/databuilder/tests/unit/resources/fs_atlas_csv_loader/movies/entities/000_Actor.csv b/databuilder/tests/unit/resources/fs_atlas_csv_loader/movies/entities/000_Actor.csv new file mode 100644 index 0000000000..42f6b963c6 --- /dev/null +++ b/databuilder/tests/unit/resources/fs_atlas_csv_loader/movies/entities/000_Actor.csv @@ -0,0 +1,3 @@ +"name","operation","qualifiedName","typeName","relationships" +"Tom Cruise","CREATE","actor://Tom Cruise","Actor", +"Meg Ryan","CREATE","actor://Meg Ryan","Actor", diff --git a/databuilder/tests/unit/resources/fs_atlas_csv_loader/movies/entities/000_City.csv b/databuilder/tests/unit/resources/fs_atlas_csv_loader/movies/entities/000_City.csv new file mode 100644 index 0000000000..cf7c1cf7da --- /dev/null +++ b/databuilder/tests/unit/resources/fs_atlas_csv_loader/movies/entities/000_City.csv @@ -0,0 +1,3 @@ +"name","operation","qualifiedName","typeName","relationships" +"San Diego","CREATE","city://San Diego","City","" +"Oakland","CREATE","city://Oakland","City","" diff --git a/databuilder/tests/unit/resources/fs_atlas_csv_loader/movies/entities/000_Movie.csv b/databuilder/tests/unit/resources/fs_atlas_csv_loader/movies/entities/000_Movie.csv new file mode 100644 index 0000000000..dc34661068 --- /dev/null +++ b/databuilder/tests/unit/resources/fs_atlas_csv_loader/movies/entities/000_Movie.csv @@ -0,0 +1,2 @@ +"name","operation","qualifiedName","typeName","relationships" +"Top Gun","CREATE","movie://Top Gun","Movie","actors#ACTOR#actor://Tom Cruise|actors#ACTOR#actor://Meg Ryan" diff --git a/databuilder/tests/unit/resources/fs_atlas_csv_loader/movies/relationships/001_Movie_City.csv b/databuilder/tests/unit/resources/fs_atlas_csv_loader/movies/relationships/001_Movie_City.csv new file mode 100644 index 0000000000..f8ad441801 --- /dev/null +++ b/databuilder/tests/unit/resources/fs_atlas_csv_loader/movies/relationships/001_Movie_City.csv @@ -0,0 +1,3 @@ +"relationshipType","entityType1","entityQualifiedName1","entityType2","entityQualifiedName2" +"FILMED_AT","Movie","movie://Top Gun","City","city://San Diego" +"FILMED_AT","Movie","movie://Top Gun","City","city://Oakland" diff --git a/databuilder/tests/unit/resources/fs_mysql_csv_loader/records/actor_0.csv b/databuilder/tests/unit/resources/fs_mysql_csv_loader/records/actor_0.csv new file mode 100644 index 0000000000..74316ac27b --- /dev/null +++ b/databuilder/tests/unit/resources/fs_mysql_csv_loader/records/actor_0.csv @@ -0,0 +1,3 @@ +"rk","name" +"actor://Tom Cruise","Tom Cruise" +"actor://Meg Ryan","Meg Ryan" diff --git a/databuilder/tests/unit/resources/fs_mysql_csv_loader/records/movie_0.csv b/databuilder/tests/unit/resources/fs_mysql_csv_loader/records/movie_0.csv new file mode 100644 index 0000000000..0a57cbbf57 --- /dev/null +++ b/databuilder/tests/unit/resources/fs_mysql_csv_loader/records/movie_0.csv @@ -0,0 +1,2 @@ +"rk","name" +"movie://Top Gun","Top Gun" diff --git a/databuilder/tests/unit/resources/fs_mysql_csv_loader/records/movie_actor_1.csv b/databuilder/tests/unit/resources/fs_mysql_csv_loader/records/movie_actor_1.csv new file mode 100644 index 0000000000..45c9af73eb --- /dev/null +++ b/databuilder/tests/unit/resources/fs_mysql_csv_loader/records/movie_actor_1.csv @@ -0,0 +1,3 @@ +"movie_rk","actor_rk" +"movie://Top Gun","actor://Tom Cruise" +"movie://Top Gun","actor://Meg Ryan" diff --git a/databuilder/tests/unit/resources/fs_neo4j_csv_loader/movies/nodes/Actor_0.csv b/databuilder/tests/unit/resources/fs_neo4j_csv_loader/movies/nodes/Actor_0.csv new file mode 100644 index 
0000000000..659eb864ea --- /dev/null +++ b/databuilder/tests/unit/resources/fs_neo4j_csv_loader/movies/nodes/Actor_0.csv @@ -0,0 +1,3 @@ +"name","KEY","LABEL" +"Top Gun","actor://Tom Cruise","Actor" +"Top Gun","actor://Meg Ryan","Actor" diff --git a/databuilder/tests/unit/resources/fs_neo4j_csv_loader/movies/nodes/City_0.csv b/databuilder/tests/unit/resources/fs_neo4j_csv_loader/movies/nodes/City_0.csv new file mode 100644 index 0000000000..a05fba3522 --- /dev/null +++ b/databuilder/tests/unit/resources/fs_neo4j_csv_loader/movies/nodes/City_0.csv @@ -0,0 +1,3 @@ +"name","KEY","LABEL" +"Top Gun","city://San Diego","City" +"Top Gun","city://Oakland","City" diff --git a/databuilder/tests/unit/resources/fs_neo4j_csv_loader/movies/nodes/Movie_0.csv b/databuilder/tests/unit/resources/fs_neo4j_csv_loader/movies/nodes/Movie_0.csv new file mode 100644 index 0000000000..ee36b278fb --- /dev/null +++ b/databuilder/tests/unit/resources/fs_neo4j_csv_loader/movies/nodes/Movie_0.csv @@ -0,0 +1,2 @@ +"name","KEY","LABEL" +"Top Gun","movie://Top Gun","Movie" diff --git a/databuilder/tests/unit/resources/fs_neo4j_csv_loader/movies/relationships/test_Movie_Actor_ACTOR.csv b/databuilder/tests/unit/resources/fs_neo4j_csv_loader/movies/relationships/test_Movie_Actor_ACTOR.csv new file mode 100644 index 0000000000..8687ae49cd --- /dev/null +++ b/databuilder/tests/unit/resources/fs_neo4j_csv_loader/movies/relationships/test_Movie_Actor_ACTOR.csv @@ -0,0 +1,3 @@ +"END_KEY","START_LABEL","END_LABEL","START_KEY","TYPE","REVERSE_TYPE" +"actor://Tom Cruise","Movie","Actor","movie://Top Gun","ACTOR","ACTED_IN" +"actor://Meg Ryan","Movie","Actor","movie://Top Gun","ACTOR","ACTED_IN" diff --git a/databuilder/tests/unit/resources/fs_neo4j_csv_loader/movies/relationships/test_Movie_City_FILMED_AT.csv b/databuilder/tests/unit/resources/fs_neo4j_csv_loader/movies/relationships/test_Movie_City_FILMED_AT.csv new file mode 100644 index 0000000000..bc81384599 --- /dev/null +++ b/databuilder/tests/unit/resources/fs_neo4j_csv_loader/movies/relationships/test_Movie_City_FILMED_AT.csv @@ -0,0 +1,3 @@ +"END_KEY","START_LABEL","END_LABEL","START_KEY","TYPE","REVERSE_TYPE" +"city://San Diego","Movie","City","movie://Top Gun","FILMED_AT","APPEARS_IN" +"city://Oakland","Movie","City","movie://Top Gun","FILMED_AT","APPEARS_IN" diff --git a/databuilder/tests/unit/resources/fs_neo4j_csv_loader/people/nodes/Person_0.csv b/databuilder/tests/unit/resources/fs_neo4j_csv_loader/people/nodes/Person_0.csv new file mode 100644 index 0000000000..ce9e160e91 --- /dev/null +++ b/databuilder/tests/unit/resources/fs_neo4j_csv_loader/people/nodes/Person_0.csv @@ -0,0 +1,2 @@ +"name","job","KEY","LABEL" +"Taylor","Engineer","person://Taylor","Person" diff --git a/databuilder/tests/unit/resources/fs_neo4j_csv_loader/people/nodes/Person_1.csv b/databuilder/tests/unit/resources/fs_neo4j_csv_loader/people/nodes/Person_1.csv new file mode 100644 index 0000000000..2fa909c4f0 --- /dev/null +++ b/databuilder/tests/unit/resources/fs_neo4j_csv_loader/people/nodes/Person_1.csv @@ -0,0 +1,2 @@ +"name","pet","KEY","LABEL" +"Griffin","Lion","person://Griffin","Person" diff --git a/databuilder/tests/unit/resources/fs_neptune_csv_loader/nodes/Actor_6.csv b/databuilder/tests/unit/resources/fs_neptune_csv_loader/nodes/Actor_6.csv new file mode 100644 index 0000000000..2bd671b6f0 --- /dev/null +++ b/databuilder/tests/unit/resources/fs_neptune_csv_loader/nodes/Actor_6.csv @@ -0,0 +1,3 @@ 
+"~id","~label","last_extracted_datetime:Date(single)","creation_type:String(single)","name:String(single)","published_tag:String(single)","key:String(single)" +"Actor:actor://Tom Cruise","Actor","2020-09-01T01:01:00","job","Top Gun","TESTED","actor://Tom Cruise" +"Actor:actor://Meg Ryan","Actor","2020-09-01T01:01:00","job","Top Gun","TESTED","actor://Meg Ryan" diff --git a/databuilder/tests/unit/resources/fs_neptune_csv_loader/nodes/City_6.csv b/databuilder/tests/unit/resources/fs_neptune_csv_loader/nodes/City_6.csv new file mode 100644 index 0000000000..3ed59a7c81 --- /dev/null +++ b/databuilder/tests/unit/resources/fs_neptune_csv_loader/nodes/City_6.csv @@ -0,0 +1,3 @@ +"~id","~label","last_extracted_datetime:Date(single)","creation_type:String(single)","name:String(single)","published_tag:String(single)","key:String(single)" +"City:city://San Diego","City","2020-09-01T01:01:00","job","Top Gun","TESTED","city://San Diego" +"City:city://Oakland","City","2020-09-01T01:01:00","job","Top Gun","TESTED","city://Oakland" diff --git a/databuilder/tests/unit/resources/fs_neptune_csv_loader/nodes/Movie_6.csv b/databuilder/tests/unit/resources/fs_neptune_csv_loader/nodes/Movie_6.csv new file mode 100644 index 0000000000..5da52565ed --- /dev/null +++ b/databuilder/tests/unit/resources/fs_neptune_csv_loader/nodes/Movie_6.csv @@ -0,0 +1,2 @@ +"~id","~label","last_extracted_datetime:Date(single)","creation_type:String(single)","name:String(single)","published_tag:String(single)","key:String(single)" +"Movie:movie://Top Gun","Movie","2020-09-01T01:01:00","job","Top Gun","TESTED","movie://Top Gun" diff --git a/databuilder/tests/unit/resources/fs_neptune_csv_loader/relationships/Movie_Actor_ACTOR.csv b/databuilder/tests/unit/resources/fs_neptune_csv_loader/relationships/Movie_Actor_ACTOR.csv new file mode 100644 index 0000000000..77cb35a7b7 --- /dev/null +++ b/databuilder/tests/unit/resources/fs_neptune_csv_loader/relationships/Movie_Actor_ACTOR.csv @@ -0,0 +1,5 @@ +"~id","~from","~to","~label","last_extracted_datetime:Date(single)","creation_type:String","published_tag:String(single)","key:String(single)" +"ACTOR:Movie:movie://Top Gun_Actor:actor://Tom Cruise","Movie:movie://Top Gun","Actor:actor://Tom Cruise","ACTOR","2020-09-01T01:01:00","job","TESTED","ACTOR:Movie:movie://Top Gun_Actor:actor://Tom Cruise" +"ACTED_IN:Actor:actor://Tom Cruise_Movie:movie://Top Gun","Actor:actor://Tom Cruise","Movie:movie://Top Gun","ACTED_IN","2020-09-01T01:01:00","job","TESTED","ACTED_IN:Actor:actor://Tom Cruise_Movie:movie://Top Gun" +"ACTOR:Movie:movie://Top Gun_Actor:actor://Meg Ryan","Movie:movie://Top Gun","Actor:actor://Meg Ryan","ACTOR","2020-09-01T01:01:00","job","TESTED","ACTOR:Movie:movie://Top Gun_Actor:actor://Meg Ryan" +"ACTED_IN:Actor:actor://Meg Ryan_Movie:movie://Top Gun","Actor:actor://Meg Ryan","Movie:movie://Top Gun","ACTED_IN","2020-09-01T01:01:00","job","TESTED","ACTED_IN:Actor:actor://Meg Ryan_Movie:movie://Top Gun" diff --git a/databuilder/tests/unit/resources/fs_neptune_csv_loader/relationships/Movie_City_FILMED_AT.csv b/databuilder/tests/unit/resources/fs_neptune_csv_loader/relationships/Movie_City_FILMED_AT.csv new file mode 100644 index 0000000000..7b1bec260a --- /dev/null +++ b/databuilder/tests/unit/resources/fs_neptune_csv_loader/relationships/Movie_City_FILMED_AT.csv @@ -0,0 +1,5 @@ +"~id","~from","~to","~label","last_extracted_datetime:Date(single)","creation_type:String","published_tag:String(single)","key:String(single)" +"FILMED_AT:Movie:movie://Top Gun_City:city://San 
Diego","Movie:movie://Top Gun","City:city://San Diego","FILMED_AT","2020-09-01T01:01:00","job","TESTED","FILMED_AT:Movie:movie://Top Gun_City:city://San Diego" +"APPEARS_IN:City:city://San Diego_Movie:movie://Top Gun","City:city://San Diego","Movie:movie://Top Gun","APPEARS_IN","2020-09-01T01:01:00","job","TESTED","APPEARS_IN:City:city://San Diego_Movie:movie://Top Gun" +"FILMED_AT:Movie:movie://Top Gun_City:city://Oakland","Movie:movie://Top Gun","City:city://Oakland","FILMED_AT","2020-09-01T01:01:00","job","TESTED","FILMED_AT:Movie:movie://Top Gun_City:city://Oakland" +"APPEARS_IN:City:city://Oakland_Movie:movie://Top Gun","City:city://Oakland","Movie:movie://Top Gun","APPEARS_IN","2020-09-01T01:01:00","job","TESTED","APPEARS_IN:City:city://Oakland_Movie:movie://Top Gun" diff --git a/databuilder/tests/unit/resources/mysql_csv_publisher/records/actor_0.csv b/databuilder/tests/unit/resources/mysql_csv_publisher/records/actor_0.csv new file mode 100644 index 0000000000..bcfe9aee9e --- /dev/null +++ b/databuilder/tests/unit/resources/mysql_csv_publisher/records/actor_0.csv @@ -0,0 +1,3 @@ +"rk","name" +"actor://Tom Cruise","Tom Cruise" +"actor://Meg Ryan","Meg Ryan" diff --git a/databuilder/tests/unit/resources/mysql_csv_publisher/records/movie_0.csv b/databuilder/tests/unit/resources/mysql_csv_publisher/records/movie_0.csv new file mode 100644 index 0000000000..bb66031cf1 --- /dev/null +++ b/databuilder/tests/unit/resources/mysql_csv_publisher/records/movie_0.csv @@ -0,0 +1,2 @@ +"rk","name" +"movie://Top Gun","Top Gun" diff --git a/databuilder/tests/unit/resources/mysql_csv_publisher/records/movie_actor_1.csv b/databuilder/tests/unit/resources/mysql_csv_publisher/records/movie_actor_1.csv new file mode 100644 index 0000000000..ea4b2d271a --- /dev/null +++ b/databuilder/tests/unit/resources/mysql_csv_publisher/records/movie_actor_1.csv @@ -0,0 +1,3 @@ +"movie_rk","actor_rk" +"movie://Top Gun","actor://Tom Cruise" +"movie://Top Gun","actor://Meg Ryan" diff --git a/databuilder/tests/unit/rest_api/__init__.py b/databuilder/tests/unit/rest_api/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/rest_api/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/rest_api/mode_analytics/__init__.py b/databuilder/tests/unit/rest_api/mode_analytics/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/rest_api/mode_analytics/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/rest_api/mode_analytics/test_mode_paginated_rest_api_query.py b/databuilder/tests/unit/rest_api/mode_analytics/test_mode_paginated_rest_api_query.py new file mode 100644 index 0000000000..c0f9d863a3 --- /dev/null +++ b/databuilder/tests/unit/rest_api/mode_analytics/test_mode_paginated_rest_api_query.py @@ -0,0 +1,102 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest + +from mock import call, patch + +from databuilder.rest_api.base_rest_api_query import RestApiQuerySeed +from databuilder.rest_api.mode_analytics.mode_paginated_rest_api_query import ModePaginatedRestApiQuery + +logging.basicConfig(level=logging.INFO) + + +class TestModePaginatedRestApiQuery(unittest.TestCase): + + def test_pagination(self) -> None: + seed_record = [{'foo1': 'bar1'}, + {'foo2': 'bar2'}] + seed_query = RestApiQuerySeed(seed_record=seed_record) + + with patch('databuilder.rest_api.rest_api_query.requests.get') as mock_get: + json_path = 'foo[*].name' + field_names = ['name_field'] + + mock_get.return_value.json.side_effect = [ # need to duplicate for json() is called twice + {'foo': [{'name': 'v1'}, {'name': 'v2'}]}, + {'foo': [{'name': 'v1'}, {'name': 'v2'}]}, + {'foo': [{'name': 'v3'}]}, + {'foo': [{'name': 'v3'}]}, + {'foo': [{'name': 'v4'}, {'name': 'v5'}]}, + {'foo': [{'name': 'v4'}, {'name': 'v5'}]}, + {}, + {} + ] + + query = ModePaginatedRestApiQuery(query_to_join=seed_query, url='foobar', params={}, + json_path=json_path, field_names=field_names, + skip_no_result=True, pagination_json_path='foo[*]', + max_record_size=2) + + expected_list = [ + {'name_field': 'v1', 'foo1': 'bar1'}, + {'name_field': 'v2', 'foo1': 'bar1'}, + {'name_field': 'v3', 'foo1': 'bar1'}, + {'name_field': 'v4', 'foo2': 'bar2'}, + {'name_field': 'v5', 'foo2': 'bar2'} + ] + for actual in query.execute(): + self.assertDictEqual(actual, expected_list.pop(0)) + + self.assertEqual(mock_get.call_count, 4) + + calls = [ + call('foobar?page=1'), + call('foobar?page=2') + ] + mock_get.assert_has_calls(calls, any_order=True) + + def test_no_pagination(self) -> None: + seed_record = [{'foo1': 'bar1'}, + {'foo2': 'bar2'}, + {'foo3': 'bar3'}] + seed_query = RestApiQuerySeed(seed_record=seed_record) + + with patch('databuilder.rest_api.rest_api_query.requests.get') as mock_get: + json_path = 'foo[*].name' + field_names = ['name_field'] + + mock_get.return_value.json.side_effect = [ # need to duplicate for json() is called twice + {'foo': [{'name': 'v1'}, {'name': 'v2'}]}, + {'foo': [{'name': 'v1'}, {'name': 'v2'}]}, + {'foo': [{'name': 'v3'}]}, + {'foo': [{'name': 'v3'}]}, + {'foo': [{'name': 'v4'}, {'name': 'v5'}]}, + {'foo': [{'name': 'v4'}, {'name': 'v5'}]}, + ] + + query = ModePaginatedRestApiQuery(query_to_join=seed_query, url='foobar', params={}, + json_path=json_path, field_names=field_names, + pagination_json_path='foo[*]', + max_record_size=3) + + expected_list = [ + {'name_field': 'v1', 'foo1': 'bar1'}, + {'name_field': 'v2', 'foo1': 'bar1'}, + {'name_field': 'v3', 'foo2': 'bar2'}, + {'name_field': 'v4', 'foo3': 'bar3'}, + {'name_field': 'v5', 'foo3': 'bar3'} + ] + for actual in query.execute(): + self.assertDictEqual(actual, expected_list.pop(0)) + + self.assertEqual(mock_get.call_count, 3) + calls = [ + call('foobar?page=1') + ] + mock_get.assert_has_calls(calls, any_order=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/rest_api/test_query_merger.py b/databuilder/tests/unit/rest_api/test_query_merger.py new file mode 100644 index 0000000000..97aa7a512b --- /dev/null +++ b/databuilder/tests/unit/rest_api/test_query_merger.py @@ -0,0 +1,99 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from mock import patch + +from databuilder.rest_api.base_rest_api_query import RestApiQuerySeed +from databuilder.rest_api.query_merger import QueryMerger +from databuilder.rest_api.rest_api_query import RestApiQuery + + +class TestQueryMerger(unittest.TestCase): + def setUp(self) -> None: + query_to_join_seed_record = [ + {'foo1': 'bar1', 'dashboard_id': 'd1'}, + {'foo2': 'bar2', 'dashboard_id': 'd3'} + ] + self.query_to_join = RestApiQuerySeed(seed_record=query_to_join_seed_record) + self.json_path = 'foo.name' + self.field_names = ['name_field'] + self.url = 'foobar' + + def test_ensure_record_get_updated(self) -> None: + query_to_merge_seed_record = [ + {'organization': 'amundsen', 'dashboard_id': 'd1'}, + {'organization': 'amundsen-databuilder', 'dashboard_id': 'd2'}, + {'organization': 'amundsen-dashboard', 'dashboard_id': 'd3'}, + ] + query_to_merge = RestApiQuerySeed(seed_record=query_to_merge_seed_record) + query_merger = QueryMerger(query_to_merge=query_to_merge, merge_key='dashboard_id') + + with patch('databuilder.rest_api.rest_api_query.requests.get') as mock_get: + mock_get.return_value.json.side_effect = [ + {'foo': {'name': 'john'}}, + {'foo': {'name': 'doe'}}, + ] + query = RestApiQuery(query_to_join=self.query_to_join, url=self.url, params={}, + json_path=self.json_path, field_names=self.field_names, + query_merger=query_merger) + results = list(query.execute()) + self.assertEqual(len(results), 2) + self.assertDictEqual( + {'dashboard_id': 'd1', 'foo1': 'bar1', 'name_field': 'john', 'organization': 'amundsen'}, + results[0], + ) + self.assertDictEqual( + {'dashboard_id': 'd3', 'foo2': 'bar2', 'name_field': 'doe', 'organization': 'amundsen-dashboard'}, + results[1], + ) + + def test_exception_raised_with_duplicate_merge_key(self) -> None: + """ + Two records in query_to_merge results have {'dashboard_id': 'd2'}, + exception should be raised + """ + query_to_merge_seed_record = [ + {'organization': 'amundsen', 'dashboard_id': 'd1'}, + {'organization': 'amundsen-databuilder', 'dashboard_id': 'd2'}, + {'organization': 'amundsen-dashboard', 'dashboard_id': 'd2'}, + ] + query_to_merge = RestApiQuerySeed(seed_record=query_to_merge_seed_record) + query_merger = QueryMerger(query_to_merge=query_to_merge, merge_key='dashboard_id') + + with patch('databuilder.rest_api.rest_api_query.requests.get') as mock_get: + mock_get.return_value.json.side_effect = [ + {'foo': {'name': 'john'}}, + {'foo': {'name': 'doe'}}, + ] + query = RestApiQuery(query_to_join=self.query_to_join, url=self.url, params={}, + json_path=self.json_path, field_names=self.field_names, + query_merger=query_merger) + self.assertRaises(Exception, query.execute()) # type: ignore + + def test_exception_raised_with_missing_merge_key(self) -> None: + """ + No record in query_to_merge results contains {'dashboard_id': 'd3'}, + exception should be raised + """ + query_to_merge_seed_record = [ + {'organization': 'amundsen', 'dashboard_id': 'd1'}, + {'organization': 'amundsen-databuilder', 'dashboard_id': 'd2'}, + ] + query_to_merge = RestApiQuerySeed(seed_record=query_to_merge_seed_record) + query_merger = QueryMerger(query_to_merge=query_to_merge, merge_key='dashboard_id') + + with patch('databuilder.rest_api.rest_api_query.requests.get') as mock_get: + mock_get.return_value.json.side_effect = [ + {'foo': {'name': 'john'}}, + {'foo': {'name': 'doe'}}, + ] + query = RestApiQuery(query_to_join=self.query_to_join, url=self.url, params={}, + json_path=self.json_path, 
field_names=self.field_names, + query_merger=query_merger) + self.assertRaises(Exception, query.execute()) # type: ignore + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/rest_api/test_rest_api_failure_handlers.py b/databuilder/tests/unit/rest_api/test_rest_api_failure_handlers.py new file mode 100644 index 0000000000..e115ac4c76 --- /dev/null +++ b/databuilder/tests/unit/rest_api/test_rest_api_failure_handlers.py @@ -0,0 +1,24 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from mock import MagicMock + +from databuilder.rest_api.rest_api_failure_handlers import HttpFailureSkipOnStatus + + +class TestHttpFailureSkipOnStatus(unittest.TestCase): + + def testSkip(self) -> None: + failure_handler = HttpFailureSkipOnStatus([404, 400]) + + exception = MagicMock() + exception.response.status_code = 404 + self.assertTrue(failure_handler.can_skip_failure(exception=exception)) + + exception.response.status_code = 400 + self.assertTrue(failure_handler.can_skip_failure(exception=exception)) + + exception.response.status_code = 500 + self.assertFalse(failure_handler.can_skip_failure(exception=exception)) diff --git a/databuilder/tests/unit/rest_api/test_rest_api_query.py b/databuilder/tests/unit/rest_api/test_rest_api_query.py new file mode 100644 index 0000000000..b1e0271a93 --- /dev/null +++ b/databuilder/tests/unit/rest_api/test_rest_api_query.py @@ -0,0 +1,143 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from mock import patch + +from databuilder.rest_api.base_rest_api_query import EmptyRestApiQuerySeed, RestApiQuerySeed +from databuilder.rest_api.rest_api_query import RestApiQuery + + +class TestRestApiQuery(unittest.TestCase): + + def test_rest_api_query_seed(self) -> None: + rest_api_query = RestApiQuerySeed(seed_record=[ + {'foo': 'bar'}, + {'john': 'doe'} + ]) + + result = [v for v in rest_api_query.execute()] + expected = [ + {'foo': 'bar'}, + {'john': 'doe'} + ] + + self.assertListEqual(expected, result) + + def test_empty_rest_api_query_seed(self) -> None: + rest_api_query = EmptyRestApiQuerySeed() + + result = [v for v in rest_api_query.execute()] + assert len(result) == 1 + + def test_rest_api_query(self) -> None: + + seed_record = [{'foo1': 'bar1'}, + {'foo2': 'bar2'}] + seed_query = RestApiQuerySeed(seed_record=seed_record) + + with patch('databuilder.rest_api.rest_api_query.requests.get') as mock_get: + json_path = 'foo.name' + field_names = ['name_field'] + + mock_get.return_value.json.side_effect = [ + {'foo': {'name': 'john'}}, + {'foo': {'name': 'doe'}}, + ] + query = RestApiQuery(query_to_join=seed_query, url='foobar', params={}, + json_path=json_path, field_names=field_names) + + expected = [ + {'name_field': 'john', 'foo1': 'bar1'}, + {'name_field': 'doe', 'foo2': 'bar2'} + ] + + for actual in query.execute(): + self.assertDictEqual(expected.pop(0), actual) + + def test_rest_api_query_multiple_fields(self) -> None: + + seed_record = [{'foo1': 'bar1'}, + {'foo2': 'bar2'}] + seed_query = RestApiQuerySeed(seed_record=seed_record) + + with patch('databuilder.rest_api.rest_api_query.requests.get') as mock_get: + json_path = 'foo.[name,hobby]' + field_names = ['name_field', 'hobby'] + + mock_get.return_value.json.side_effect = [ + {'foo': {'name': 'john', 'hobby': 'skiing'}}, + {'foo': {'name': 'doe', 'hobby': 'snowboarding'}}, + ] + query = RestApiQuery(query_to_join=seed_query, url='foobar', params={}, + 
json_path=json_path, field_names=field_names) + + expected = [ + {'name_field': 'john', 'hobby': 'skiing', 'foo1': 'bar1'}, + {'name_field': 'doe', 'hobby': 'snowboarding', 'foo2': 'bar2'} + ] + + for actual in query.execute(): + self.assertDictEqual(expected.pop(0), actual) + + def test_compute_subresult_single_field(self) -> None: + sub_records = RestApiQuery._compute_sub_records(result_list=['1', '2', '3'], field_names=['foo']) + + expected_records = [ + ['1'], ['2'], ['3'] + ] + + assert expected_records == sub_records + + sub_records = RestApiQuery._compute_sub_records(result_list=['1', '2', '3'], field_names=['foo'], + json_path_contains_or=True) + + assert expected_records == sub_records + + def test_compute_subresult_multiple_fields_json_path_and_expression(self) -> None: + sub_records = RestApiQuery._compute_sub_records( + result_list=['1', 'a', '2', 'b', '3', 'c'], field_names=['foo', 'bar']) + + expected_records = [ + ['1', 'a'], ['2', 'b'], ['3', 'c'] + ] + + assert expected_records == sub_records + + sub_records = RestApiQuery._compute_sub_records( + result_list=['1', 'a', 'x', '2', 'b', 'y', '3', 'c', 'z'], field_names=['foo', 'bar', 'baz']) + + expected_records = [ + ['1', 'a', 'x'], ['2', 'b', 'y'], ['3', 'c', 'z'] + ] + + assert expected_records == sub_records + + def test_compute_subresult_multiple_fields_json_path_or_expression(self) -> None: + sub_records = RestApiQuery._compute_sub_records( + result_list=['1', '2', '3', 'a', 'b', 'c'], + field_names=['foo', 'bar'], + json_path_contains_or=True + ) + + expected_records = [ + ['1', 'a'], ['2', 'b'], ['3', 'c'] + ] + + self.assertEqual(expected_records, sub_records) + + sub_records = RestApiQuery._compute_sub_records( + result_list=['1', '2', '3', 'a', 'b', 'c', 'x', 'y', 'z'], + field_names=['foo', 'bar', 'baz'], + json_path_contains_or=True) + + expected_records = [ + ['1', 'a', 'x'], ['2', 'b', 'y'], ['3', 'c', 'z'] + ] + + self.assertEqual(expected_records, sub_records) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/task/__init__.py b/databuilder/tests/unit/task/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/task/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/task/test_mysql_staleness_removal_task.py b/databuilder/tests/unit/task/test_mysql_staleness_removal_task.py new file mode 100644 index 0000000000..1133247a9f --- /dev/null +++ b/databuilder/tests/unit/task/test_mysql_staleness_removal_task.py @@ -0,0 +1,151 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import unittest +from typing import Any +from unittest.mock import patch + +from amundsen_rds.models.table import Table +from pyhocon import ConfigFactory + +from databuilder.publisher.mysql_csv_publisher import MySQLCSVPublisher +from databuilder.task import mysql_staleness_removal_task +from databuilder.task.mysql_staleness_removal_task import MySQLStalenessRemovalTask + + +class TestMySQLStalenessRemovalTask(unittest.TestCase): + + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + @patch.object(mysql_staleness_removal_task, 'sessionmaker') + @patch.object(mysql_staleness_removal_task, 'create_engine') + def test_marker(self, mock_create_engine: Any, mock_session_maker: Any) -> None: + task = MySQLStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + 'job.identifier': 'mysql_remove_stale_data_job', + f'{task.get_scope()}.{MySQLStalenessRemovalTask.CONN_STRING}': 'foobar', + f'{task.get_scope()}.{MySQLStalenessRemovalTask.STALENESS_MAX_PCT}': 5, + MySQLCSVPublisher.JOB_PUBLISH_TAG: 'foo' + }) + task.init(job_config) + + self.assertIsNone(task.ms_to_expire) + self.assertEqual(task.marker, 'foo') + + task = MySQLStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + 'job.identifier': 'mysql_remove_stale_data_job', + f'{task.get_scope()}.{MySQLStalenessRemovalTask.CONN_STRING}': 'foobar', + f'{task.get_scope()}.{MySQLStalenessRemovalTask.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{MySQLStalenessRemovalTask.MS_TO_EXPIRE}': 86400000 + }) + task.init(job_config) + + self.assertIsNotNone(task.ms_to_expire) + self.assertEqual(task.marker, 86400000) + + @patch.object(mysql_staleness_removal_task, 'sessionmaker') + @patch.object(mysql_staleness_removal_task, 'create_engine') + def test_config_with_publish_tag_and_ms_to_expire(self, mock_create_engine: Any, mock_session_maker: Any) -> None: + task = MySQLStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + 'job.identifier': 'mysql_remove_stale_data_job', + f'{task.get_scope()}.{MySQLStalenessRemovalTask.CONN_STRING}': 'foobar', + f'{task.get_scope()}.{MySQLStalenessRemovalTask.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{MySQLStalenessRemovalTask.MS_TO_EXPIRE}': 86400000, + MySQLCSVPublisher.JOB_PUBLISH_TAG: 'foo' + }) + + self.assertRaises(Exception, task.init, job_config) + + @patch.object(mysql_staleness_removal_task, 'sessionmaker') + @patch.object(mysql_staleness_removal_task, 'create_engine') + def test_ms_to_expire_too_small(self, mock_create_engine: Any, mock_session_maker: Any) -> None: + task = MySQLStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + 'job.identifier': 'mysql_remove_stale_data_job', + f'{task.get_scope()}.{MySQLStalenessRemovalTask.CONN_STRING}': 'foobar', + f'{task.get_scope()}.{MySQLStalenessRemovalTask.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{MySQLStalenessRemovalTask.MS_TO_EXPIRE}': 24 * 60 * 60 * 100, + }) + + self.assertRaises(Exception, task.init, job_config) + + @patch.object(mysql_staleness_removal_task, 'sessionmaker') + @patch.object(mysql_staleness_removal_task, 'create_engine') + def test_validation_threshold_override(self, mock_create_engine: Any, mock_session_maker: Any) -> None: + task = MySQLStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + 'job.identifier': 'mysql_remove_stale_data_job', + f'{task.get_scope()}.{MySQLStalenessRemovalTask.CONN_STRING}': 'foobar', + f'{task.get_scope()}.{MySQLStalenessRemovalTask.STALENESS_MAX_PCT}': 5, + 
f'{task.get_scope()}.{MySQLStalenessRemovalTask.STALENESS_PCT_MAX_DICT}': {'table_metadata': 30}, + f'{task.get_scope()}.{MySQLStalenessRemovalTask.TARGET_TABLES}': ['table_metadata'], + MySQLCSVPublisher.JOB_PUBLISH_TAG: 'foo' + }) + mock_total_records_query = mock_session_maker.return_value.return_value.query.return_value.scalar + mock_total_records_query.return_value = 10 + mock_stale_records_query = mock_session_maker.return_value.return_value \ + .query.return_value.filter.return_value.scalar + mock_stale_records_query.return_value = 5 + + task.init(job_config) + + self.assertRaises(Exception, task._validate_record_staleness_pct, 'table_metadata', Table, 'rk') + + @patch.object(mysql_staleness_removal_task, 'sessionmaker') + @patch.object(mysql_staleness_removal_task, 'create_engine') + def test_dry_run(self, mock_create_engine: Any, mock_session_maker: Any) -> None: + task = MySQLStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + 'job.identifier': 'mysql_remove_stale_data_job', + f'{task.get_scope()}.{MySQLStalenessRemovalTask.CONN_STRING}': 'foobar', + f'{task.get_scope()}.{MySQLStalenessRemovalTask.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{MySQLStalenessRemovalTask.TARGET_TABLES}': ['foo'], + f'{task.get_scope()}.{MySQLStalenessRemovalTask.DRY_RUN}': True, + MySQLCSVPublisher.JOB_PUBLISH_TAG: 'foo' + }) + mock_commit = mock_session_maker.return_value.return_value.commit + + task.init(job_config) + + mock_commit.assert_not_called() + + @patch.object(mysql_staleness_removal_task, 'sessionmaker') + @patch.object(mysql_staleness_removal_task, 'create_engine') + def test_stale_records_filter_condition(self, mock_create_engine: Any, mock_session_maker: Any) -> None: + task = MySQLStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + 'job.identifier': 'mysql_remove_stale_data_job', + f'{task.get_scope()}.{MySQLStalenessRemovalTask.CONN_STRING}': 'foobar', + f'{task.get_scope()}.{MySQLStalenessRemovalTask.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{MySQLStalenessRemovalTask.TARGET_TABLES}': ['table_metadata'], + MySQLCSVPublisher.JOB_PUBLISH_TAG: 'foo' + }) + + task.init(job_config) + filter_statement = task._get_stale_records_filter_condition(Table) + + self.assertTrue(str(filter_statement) == 'table_metadata.published_tag != :published_tag_1') + + task = MySQLStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + 'job.identifier': 'mysql_remove_stale_data_job', + f'{task.get_scope()}.{MySQLStalenessRemovalTask.CONN_STRING}': 'foobar', + f'{task.get_scope()}.{MySQLStalenessRemovalTask.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{MySQLStalenessRemovalTask.TARGET_TABLES}': ['table_metadata'], + f'{task.get_scope()}.{MySQLStalenessRemovalTask.MS_TO_EXPIRE}': 24 * 60 * 60 * 1000 + + }) + + task.init(job_config) + filter_statement = task._get_stale_records_filter_condition(Table) + + self.assertTrue(str(filter_statement) == 'table_metadata.publisher_last_updated_epoch_ms < ' + ':publisher_last_updated_epoch_ms_1') + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/task/test_neo4j_staleness_removal_task.py b/databuilder/tests/unit/task/test_neo4j_staleness_removal_task.py new file mode 100644 index 0000000000..70512dfe85 --- /dev/null +++ b/databuilder/tests/unit/task/test_neo4j_staleness_removal_task.py @@ -0,0 +1,606 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +# Validation of Cypher statements causing Flake8 to fail. 
Disabling it on this file only +# flake8: noqa + +import logging +import textwrap +import unittest + +from mock import patch +from neo4j import GraphDatabase +from pyhocon import ConfigFactory + +from databuilder.publisher import neo4j_csv_publisher +from databuilder.task import neo4j_staleness_removal_task +from databuilder.task.neo4j_staleness_removal_task import Neo4jStalenessRemovalTask, TargetWithCondition + + +class TestRemoveStaleData(unittest.TestCase): + + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + def test_validation_failure(self) -> None: + + with patch.object(GraphDatabase, 'driver'): + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'neo4j://example.com:7687', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 90, + neo4j_csv_publisher.JOB_PUBLISH_TAG: 'foo' + }) + + task.init(job_config) + total_record_count = 100 + stale_record_count = 50 + target_type = 'foo' + task._validate_staleness_pct(total_record_count=total_record_count, + stale_record_count=stale_record_count, + target_type=target_type) + + def test_validation(self) -> None: + + with patch.object(GraphDatabase, 'driver'): + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'neo4j://example.com:7687', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + neo4j_csv_publisher.JOB_PUBLISH_TAG: 'foo' + }) + + task.init(job_config) + total_record_count = 100 + stale_record_count = 50 + target_type = 'foo' + self.assertRaises(Exception, task._validate_staleness_pct, total_record_count, stale_record_count, target_type) + + def test_validation_threshold_override(self) -> None: + + with patch.object(GraphDatabase, 'driver'): + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'neo4j://example.com:7687', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_PCT_MAX_DICT}': {'foo': 51}, + neo4j_csv_publisher.JOB_PUBLISH_TAG: 'foo' + }) + + task.init(job_config) + task._validate_staleness_pct(total_record_count=100, + stale_record_count=50, + target_type='foo') + task._validate_staleness_pct(total_record_count=100, + stale_record_count=3, + target_type='bar') + + def test_marker(self) -> None: + with patch.object(GraphDatabase, 'driver'): + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'neo4j://example.com:7687', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + 
f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + neo4j_csv_publisher.JOB_PUBLISH_TAG: 'foo' + }) + + task.init(job_config) + self.assertIsNone(task.ms_to_expire) + self.assertEqual(task.marker, 'foo') + + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'neo4j://example.com:7687', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.MS_TO_EXPIRE}': 86400000, + }) + + task.init(job_config) + self.assertIsNotNone(task.ms_to_expire) + self.assertEqual(task.marker, 86400000) + + def test_validation_statement_publish_tag(self) -> None: + with patch.object(GraphDatabase, 'driver'), patch.object(Neo4jStalenessRemovalTask, '_execute_cypher_query') \ + as mock_execute: + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'neo4j://example.com:7687', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_NODES}': ['Foo'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_RELATIONS}': ['BAR'], + neo4j_csv_publisher.JOB_PUBLISH_TAG: 'foo', + }) + + task.init(job_config) + task._validate_node_staleness_pct() + + mock_execute.assert_called() + mock_execute.assert_any_call(statement=textwrap.dedent(""" + MATCH (target:Foo) + WHERE true + RETURN count(*) as count + """)) + + mock_execute.assert_any_call(param_dict={'marker': u'foo'}, + statement=textwrap.dedent(""" + MATCH (target:Foo) + WHERE (target.published_tag < $marker + OR NOT EXISTS(target.published_tag)) + RETURN count(*) as count + """)) + + task._validate_relation_staleness_pct() + mock_execute.assert_any_call(param_dict={'marker': u'foo'}, + statement=textwrap.dedent(""" + MATCH (start_node)-[target:BAR]-(end_node) + WHERE (target.published_tag < $marker + OR NOT EXISTS(target.published_tag)) + RETURN count(*) as count + """)) + + def test_validation_statement_publish_tag_retain_data_with_no_publisher_metadata(self) -> None: + with patch.object(GraphDatabase, 'driver'), patch.object(Neo4jStalenessRemovalTask, '_execute_cypher_query') \ + as mock_execute: + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'neo4j://example.com:7687', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_NODES}': ['Foo'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_RELATIONS}': ['BAR'], + neo4j_csv_publisher.JOB_PUBLISH_TAG: 'foo', + 
f'{task.get_scope()}.{neo4j_staleness_removal_task.RETAIN_DATA_WITH_NO_PUBLISHER_METADATA}': True + }) + + task.init(job_config) + task._validate_node_staleness_pct() + mock_execute.assert_any_call(param_dict={'marker': u'foo'}, + statement=textwrap.dedent(""" + MATCH (target:Foo) + WHERE (target.published_tag < $marker) + RETURN count(*) as count + """)) + + task._validate_relation_staleness_pct() + mock_execute.assert_any_call(param_dict={'marker': u'foo'}, + statement=textwrap.dedent(""" + MATCH (start_node)-[target:BAR]-(end_node) + WHERE (target.published_tag < $marker) + RETURN count(*) as count + """)) + + def test_validation_statement_ms_to_expire(self) -> None: + with patch.object(GraphDatabase, 'driver'), patch.object(Neo4jStalenessRemovalTask, '_execute_cypher_query') \ + as mock_execute: + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'neo4j://example.com:7687', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_NODES}': ['Foo'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_RELATIONS}': ['BAR'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.MS_TO_EXPIRE}': 9876543210 + }) + + task.init(job_config) + task._validate_node_staleness_pct() + + mock_execute.assert_called() + mock_execute.assert_any_call(statement=textwrap.dedent(""" + MATCH (target:Foo) + WHERE true + RETURN count(*) as count + """)) + + mock_execute.assert_any_call(param_dict={'marker': 9876543210}, + statement=textwrap.dedent(""" + MATCH (target:Foo) + WHERE (target.publisher_last_updated_epoch_ms < (timestamp() - $marker) + OR NOT EXISTS(target.publisher_last_updated_epoch_ms)) + RETURN count(*) as count + """)) + + task._validate_relation_staleness_pct() + mock_execute.assert_any_call(param_dict={'marker': 9876543210}, + statement=textwrap.dedent(""" + MATCH (start_node)-[target:BAR]-(end_node) + WHERE (target.publisher_last_updated_epoch_ms < (timestamp() - $marker) + OR NOT EXISTS(target.publisher_last_updated_epoch_ms)) + RETURN count(*) as count + """)) + + def test_validation_statement_ms_to_expire_retain_data_with_no_publisher_metadata(self) -> None: + with patch.object(GraphDatabase, 'driver'), patch.object(Neo4jStalenessRemovalTask, '_execute_cypher_query') \ + as mock_execute: + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'neo4j://example.com:7687', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_NODES}': ['Foo'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_RELATIONS}': ['BAR'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.MS_TO_EXPIRE}': 9876543210, + f'{task.get_scope()}.{neo4j_staleness_removal_task.RETAIN_DATA_WITH_NO_PUBLISHER_METADATA}': True + }) + + task.init(job_config) + task._validate_node_staleness_pct() + mock_execute.assert_any_call(param_dict={'marker': 9876543210}, + 
statement=textwrap.dedent(""" + MATCH (target:Foo) + WHERE (target.publisher_last_updated_epoch_ms < (timestamp() - $marker)) + RETURN count(*) as count + """)) + + task._validate_relation_staleness_pct() + mock_execute.assert_any_call(param_dict={'marker': 9876543210}, + statement=textwrap.dedent(""" + MATCH (start_node)-[target:BAR]-(end_node) + WHERE (target.publisher_last_updated_epoch_ms < (timestamp() - $marker)) + RETURN count(*) as count + """)) + + def test_validation_statement_with_target_condition(self) -> None: + with patch.object(GraphDatabase, 'driver'), patch.object(Neo4jStalenessRemovalTask, '_execute_cypher_query') \ + as mock_execute: + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'neo4j://example.com:7687', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_NODES}': [TargetWithCondition('Foo', '(target)-[:BAR]->(:Foo) AND target.name=\'foo_name\'')], + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_RELATIONS}': [TargetWithCondition('BAR', '(start_node:Foo)-[target]->(end_node:Foo)')], + neo4j_csv_publisher.JOB_PUBLISH_TAG: 'foo', + }) + + task.init(job_config) + task._validate_node_staleness_pct() + + mock_execute.assert_called() + mock_execute.assert_any_call(statement=textwrap.dedent(""" + MATCH (target:Foo) + WHERE true AND (target)-[:BAR]->(:Foo) AND target.name=\'foo_name\' + RETURN count(*) as count + """)) + + mock_execute.assert_any_call(param_dict={'marker': u'foo'}, + statement=textwrap.dedent(""" + MATCH (target:Foo) + WHERE (target.published_tag < $marker + OR NOT EXISTS(target.published_tag)) AND (target)-[:BAR]->(:Foo) AND target.name=\'foo_name\' + RETURN count(*) as count + """)) + + task._validate_relation_staleness_pct() + mock_execute.assert_any_call(param_dict={'marker': u'foo'}, + statement=textwrap.dedent(""" + MATCH (start_node)-[target:BAR]-(end_node) + WHERE (target.published_tag < $marker + OR NOT EXISTS(target.published_tag)) AND (start_node:Foo)-[target]->(end_node:Foo) + RETURN count(*) as count + """)) + + def test_validation_receives_correct_counts(self) -> None: + with patch.object(GraphDatabase, 'driver'), patch.object(Neo4jStalenessRemovalTask, '_execute_cypher_query') \ + as mock_execute: + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'neo4j://example.com:7687', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_NODES}': ['Foo'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_RELATIONS}': ['BAR'], + neo4j_csv_publisher.JOB_PUBLISH_TAG: 'foo', + }) + + task.init(job_config) + + with patch.object(Neo4jStalenessRemovalTask, '_validate_staleness_pct') as mock_validate: + mock_execute.side_effect = [[{'count': 100}], [{'count': 50}]] + task._validate_node_staleness_pct() + mock_validate.assert_called_with(total_record_count=100, + 
stale_record_count=50, + target_type='Foo') + + mock_execute.side_effect = [[{'count': 100}], [{'count': 50}]] + task._validate_relation_staleness_pct() + mock_validate.assert_called_with(total_record_count=100, + stale_record_count=50, + target_type='BAR') + + def test_delete_statement_publish_tag(self) -> None: + with patch.object(GraphDatabase, 'driver'), patch.object(Neo4jStalenessRemovalTask, '_execute_cypher_query') \ + as mock_execute: + mock_execute.return_value.single.return_value = {'count': 0} + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'neo4j://example.com:7687', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_NODES}': ['Foo'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_RELATIONS}': ['BAR'], + neo4j_csv_publisher.JOB_PUBLISH_TAG: 'foo', + }) + + task.init(job_config) + task._delete_stale_nodes() + task._delete_stale_relations() + + mock_execute.assert_any_call(dry_run=False, + param_dict={'marker': u'foo', 'batch_size': 100}, + statement=textwrap.dedent(""" + MATCH (target:Foo) + WHERE (target.published_tag < $marker + OR NOT EXISTS(target.published_tag)) + WITH target LIMIT $batch_size + DETACH DELETE (target) + RETURN count(*) as count + """)) + + mock_execute.assert_any_call(dry_run=False, + param_dict={'marker': u'foo', 'batch_size': 100}, + statement=textwrap.dedent(""" + MATCH (start_node)-[target:BAR]-(end_node) + WHERE (target.published_tag < $marker + OR NOT EXISTS(target.published_tag)) + WITH target LIMIT $batch_size + DELETE target + RETURN count(*) as count + """)) + + def test_delete_statement_publish_tag_retain_data_with_no_publisher_metadata(self) -> None: + with patch.object(GraphDatabase, 'driver'), patch.object(Neo4jStalenessRemovalTask, '_execute_cypher_query') \ + as mock_execute: + mock_execute.return_value.single.return_value = {'count': 0} + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'neo4j://example.com:7687', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_NODES}': ['Foo'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_RELATIONS}': ['BAR'], + neo4j_csv_publisher.JOB_PUBLISH_TAG: 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.RETAIN_DATA_WITH_NO_PUBLISHER_METADATA}': True + }) + + task.init(job_config) + task._delete_stale_nodes() + task._delete_stale_relations() + + mock_execute.assert_any_call(dry_run=False, + param_dict={'marker': u'foo', 'batch_size': 100}, + statement=textwrap.dedent(""" + MATCH (target:Foo) + WHERE (target.published_tag < $marker) + WITH target LIMIT $batch_size + DETACH DELETE (target) + RETURN count(*) as count + """)) + + mock_execute.assert_any_call(dry_run=False, + param_dict={'marker': u'foo', 'batch_size': 100}, + statement=textwrap.dedent(""" + MATCH (start_node)-[target:BAR]-(end_node) + WHERE 
(target.published_tag < $marker) + WITH target LIMIT $batch_size + DELETE target + RETURN count(*) as count + """)) + + def test_delete_statement_ms_to_expire(self) -> None: + with patch.object(GraphDatabase, 'driver'), patch.object(Neo4jStalenessRemovalTask, '_execute_cypher_query') \ + as mock_execute: + mock_execute.return_value.single.return_value = {'count': 0} + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'neo4j://example.com:7687', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_NODES}': ['Foo'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_RELATIONS}': ['BAR'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.MS_TO_EXPIRE}': 9876543210 + }) + + task.init(job_config) + task._delete_stale_nodes() + task._delete_stale_relations() + + mock_execute.assert_any_call(dry_run=False, + param_dict={'marker': 9876543210, 'batch_size': 100}, + statement=textwrap.dedent(""" + MATCH (target:Foo) + WHERE (target.publisher_last_updated_epoch_ms < (timestamp() - $marker) + OR NOT EXISTS(target.publisher_last_updated_epoch_ms)) + WITH target LIMIT $batch_size + DETACH DELETE (target) + RETURN count(*) as count + """)) + + mock_execute.assert_any_call(dry_run=False, + param_dict={'marker': 9876543210, 'batch_size': 100}, + statement=textwrap.dedent(""" + MATCH (start_node)-[target:BAR]-(end_node) + WHERE (target.publisher_last_updated_epoch_ms < (timestamp() - $marker) + OR NOT EXISTS(target.publisher_last_updated_epoch_ms)) + WITH target LIMIT $batch_size + DELETE target + RETURN count(*) as count + """)) + + def test_delete_statement_ms_to_expire_retain_data_with_no_publisher_metadata(self) -> None: + with patch.object(GraphDatabase, 'driver'), patch.object(Neo4jStalenessRemovalTask, '_execute_cypher_query') \ + as mock_execute: + mock_execute.return_value.single.return_value = {'count': 0} + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'neo4j://example.com:7687', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_NODES}': ['Foo'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_RELATIONS}': ['BAR'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.MS_TO_EXPIRE}': 9876543210, + f'{task.get_scope()}.{neo4j_staleness_removal_task.RETAIN_DATA_WITH_NO_PUBLISHER_METADATA}': True + }) + + task.init(job_config) + task._delete_stale_nodes() + task._delete_stale_relations() + + mock_execute.assert_any_call(dry_run=False, + param_dict={'marker': 9876543210, 'batch_size': 100}, + statement=textwrap.dedent(""" + MATCH (target:Foo) + WHERE (target.publisher_last_updated_epoch_ms < (timestamp() - $marker)) + WITH target LIMIT $batch_size + DETACH DELETE (target) + RETURN count(*) as count + """)) + + mock_execute.assert_any_call(dry_run=False, + param_dict={'marker': 9876543210, 'batch_size': 100}, + 
statement=textwrap.dedent(""" + MATCH (start_node)-[target:BAR]-(end_node) + WHERE (target.publisher_last_updated_epoch_ms < (timestamp() - $marker)) + WITH target LIMIT $batch_size + DELETE target + RETURN count(*) as count + """)) + + def test_delete_statement_with_target_condition(self) -> None: + with patch.object(GraphDatabase, 'driver'), patch.object(Neo4jStalenessRemovalTask, '_execute_cypher_query') \ + as mock_execute: + mock_execute.return_value.single.return_value = {'count': 0} + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'neo4j://example.com:7687', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_NODES}': [TargetWithCondition('Foo', '(target)-[:BAR]->(:Foo) AND target.name=\'foo_name\'')], + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_RELATIONS}': [TargetWithCondition('BAR', '(start_node:Foo)-[target]->(end_node:Foo)')], + neo4j_csv_publisher.JOB_PUBLISH_TAG: 'foo', + }) + + task.init(job_config) + task._delete_stale_nodes() + task._delete_stale_relations() + + mock_execute.assert_any_call(dry_run=False, + param_dict={'marker': u'foo', 'batch_size': 100}, + statement=textwrap.dedent(""" + MATCH (target:Foo) + WHERE (target.published_tag < $marker + OR NOT EXISTS(target.published_tag)) AND (target)-[:BAR]->(:Foo) AND target.name=\'foo_name\' + WITH target LIMIT $batch_size + DETACH DELETE (target) + RETURN count(*) as count + """)) + + mock_execute.assert_any_call(dry_run=False, + param_dict={'marker': u'foo', 'batch_size': 100}, + statement=textwrap.dedent(""" + MATCH (start_node)-[target:BAR]-(end_node) + WHERE (target.published_tag < $marker + OR NOT EXISTS(target.published_tag)) AND (start_node:Foo)-[target]->(end_node:Foo) + WITH target LIMIT $batch_size + DELETE target + RETURN count(*) as count + """)) + + def test_ms_to_expire_too_small(self) -> None: + with patch.object(GraphDatabase, 'driver'): + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'neo4j://example.com:7687', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_NODES}': ['Foo'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_RELATIONS}': ['BAR'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.MS_TO_EXPIRE}': 24 * 60 * 60 * 100 - 10 + }) + + try: + task.init(job_config) + self.assertTrue(False, 'Should have failed with small TTL ') + except Exception: + pass + + with patch.object(GraphDatabase, 'driver'): + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'neo4j://example.com:7687', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + 
f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_NODES}': ['Foo'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_RELATIONS}': ['BAR'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.MS_TO_EXPIRE}': 24 * 60 * 60 * 1000, + }) + task.init(job_config) + + def test_delete_dry_run(self) -> None: + with patch.object(GraphDatabase, 'driver') as mock_driver: + session_mock = mock_driver.return_value.session + + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'neo4j://example.com:7687', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_NODES}': ['Foo'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_RELATIONS}': ['BAR'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.DRY_RUN}': True, + neo4j_csv_publisher.JOB_PUBLISH_TAG: 'foo', + }) + + task.init(job_config) + task._delete_stale_nodes() + task._delete_stale_relations() + + session_mock.assert_not_called() + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/test_base_job.py b/databuilder/tests/unit/test_base_job.py new file mode 100644 index 0000000000..a6d872e138 --- /dev/null +++ b/databuilder/tests/unit/test_base_job.py @@ -0,0 +1,173 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import json +import logging +import shutil +import tempfile +import unittest +from typing import Any + +from mock import patch +from pyhocon import ConfigFactory, ConfigTree + +from databuilder.extractor.base_extractor import Extractor +from databuilder.job.job import DefaultJob +from databuilder.loader.base_loader import Loader +from databuilder.task.task import DefaultTask +from databuilder.transformer.base_transformer import Transformer + +LOGGER = logging.getLogger(__name__) + + +class TestJob(unittest.TestCase): + + def setUp(self) -> None: + self.temp_dir_path = tempfile.mkdtemp() + self.dest_file_name = f'{self.temp_dir_path}/superhero.json' + self.conf = ConfigFactory.from_dict({'loader.superhero.dest_file': self.dest_file_name}) + + def tearDown(self) -> None: + shutil.rmtree(self.temp_dir_path) + + def test_job(self) -> None: + with patch("databuilder.job.job.StatsClient") as mock_statsd: + task = DefaultTask(SuperHeroExtractor(), + SuperHeroLoader(), + transformer=SuperHeroReverseNameTransformer()) + + job = DefaultJob(self.conf, task) + job.launch() + + expected_list = ['{"hero": "Super man", "name": "tneK kralC"}', + '{"hero": "Bat man", "name": "enyaW ecurB"}'] + with open(self.dest_file_name, 'r') as file: + for expected in expected_list: + actual = file.readline().rstrip('\n') + self.assertEqual(expected, actual) + self.assertFalse(file.readline()) + + self.assertEqual(mock_statsd.call_count, 0) + + +class TestJobNoTransform(unittest.TestCase): + + def setUp(self) -> None: + self.temp_dir_path = tempfile.mkdtemp() + self.dest_file_name = f'{self.temp_dir_path}/superhero.json' + self.conf = ConfigFactory.from_dict( + {'loader.superhero.dest_file': self.dest_file_name}) + + def tearDown(self) -> None: + shutil.rmtree(self.temp_dir_path) + + def test_job(self) -> None: + 
task = DefaultTask(SuperHeroExtractor(), SuperHeroLoader()) + + job = DefaultJob(self.conf, task) + job.launch() + + expected_list = ['{"hero": "Super man", "name": "Clark Kent"}', + '{"hero": "Bat man", "name": "Bruce Wayne"}'] + with open(self.dest_file_name, 'r') as file: + for expected in expected_list: + actual = file.readline().rstrip('\n') + self.assertEqual(expected, actual) + self.assertFalse(file.readline()) + + +class TestJobStatsd(unittest.TestCase): + + def setUp(self) -> None: + self.temp_dir_path = tempfile.mkdtemp() + self.dest_file_name = f'{self.temp_dir_path}/superhero.json' + self.conf = ConfigFactory.from_dict( + {'loader.superhero.dest_file': self.dest_file_name, + 'job.is_statsd_enabled': True, + 'job.identifier': 'foobar'}) + + def tearDown(self) -> None: + shutil.rmtree(self.temp_dir_path) + + def test_job(self) -> None: + with patch("databuilder.job.job.StatsClient") as mock_statsd: + task = DefaultTask(SuperHeroExtractor(), SuperHeroLoader()) + + job = DefaultJob(self.conf, task) + job.launch() + + expected_list = ['{"hero": "Super man", "name": "Clark Kent"}', + '{"hero": "Bat man", "name": "Bruce Wayne"}'] + with open(self.dest_file_name, 'r') as file: + for expected in expected_list: + actual = file.readline().rstrip('\n') + self.assertEqual(expected, actual) + self.assertFalse(file.readline()) + + self.assertEqual(mock_statsd.return_value.incr.call_count, 1) + + +class SuperHeroExtractor(Extractor): + def __init__(self) -> None: + pass + + def init(self, conf: ConfigTree) -> None: + self.records = [SuperHero(hero='Super man', name='Clark Kent'), + SuperHero(hero='Bat man', name='Bruce Wayne')] + self.iter = iter(self.records) + + def extract(self) -> Any: + try: + return next(self.iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return 'extractor.superhero' + + +class SuperHero: + def __init__(self, + hero: str, + name: str) -> None: + self.hero = hero + self.name = name + + def __repr__(self) -> str: + return f'SuperHero(hero={self.hero}, name={self.name})' + + +class SuperHeroReverseNameTransformer(Transformer): + def __init__(self) -> None: + pass + + def init(self, conf: ConfigTree) -> None: + pass + + def transform(self, record: Any) -> Any: + record.name = record.name[::-1] + return record + + def get_scope(self) -> str: + return 'transformer.superhero' + + +class SuperHeroLoader(Loader): + def init(self, conf: ConfigTree) -> None: + self.conf = conf + dest_file_path = self.conf.get_string('dest_file') + LOGGER.info('Loading to %s', dest_file_path) + self.dest_file_obj = open(self.conf.get_string('dest_file'), 'w') + + def load(self, record: Any) -> None: + rec = json.dumps(record.__dict__, sort_keys=True) + LOGGER.info('Writing record: %s', rec) + self.dest_file_obj.write(f'{rec}\n') + self.dest_file_obj.flush() + + def get_scope(self) -> str: + return 'loader.superhero' + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/transformer/__init__.py b/databuilder/tests/unit/transformer/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/transformer/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/transformer/test_bigquery_usage_transformer.py b/databuilder/tests/unit/transformer/test_bigquery_usage_transformer.py new file mode 100644 index 0000000000..6d5a7dfd58 --- /dev/null +++ b/databuilder/tests/unit/transformer/test_bigquery_usage_transformer.py @@ -0,0 +1,56 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from pyhocon import ConfigFactory + +from databuilder.extractor.bigquery_usage_extractor import TableColumnUsageTuple +from databuilder.models.table_column_usage import TableColumnUsage +from databuilder.transformer.bigquery_usage_transformer import BigqueryUsageTransformer + + +class TestBigQueryUsageTransform(unittest.TestCase): + + DATABASE = 'bigquery' + CLUSTER = 'your-project-here' + DATASET = 'dataset' + TABLE = 'table' + COLUMN = '*' + EMAIL = 'your-user-here@test.com' + READ_COUNT = 305 + TABLE_KEY = 'bigquery://your-project-here.dataset/table' + + def test_transform_function(self) -> None: + config = ConfigFactory.from_dict({}) + + transformer = BigqueryUsageTransformer() + transformer.init(config) + + key = TableColumnUsageTuple(database=TestBigQueryUsageTransform.DATABASE, + cluster=TestBigQueryUsageTransform.CLUSTER, + schema=TestBigQueryUsageTransform.DATASET, + table=TestBigQueryUsageTransform.TABLE, + column=TestBigQueryUsageTransform.COLUMN, + email=TestBigQueryUsageTransform.EMAIL) + + t1 = (key, TestBigQueryUsageTransform.READ_COUNT) + xformed = transformer.transform(t1) + + assert xformed is not None + self.assertIsInstance(xformed, TableColumnUsage) + col_readers = list(xformed.col_readers) + self.assertEqual(len(col_readers), 1) + col_reader = col_readers[0] + self.assertEqual(col_reader.start_label, 'Table') + self.assertEqual(col_reader.start_key, TestBigQueryUsageTransform.TABLE_KEY) + self.assertEqual(col_reader.user_email, TestBigQueryUsageTransform.EMAIL) + self.assertEqual(col_reader.read_count, TestBigQueryUsageTransform.READ_COUNT) + + def test_scope(self) -> None: + config = ConfigFactory.from_dict({}) + + transformer = BigqueryUsageTransformer() + transformer.init(config) + + self.assertEqual(transformer.get_scope(), 'transformer.bigquery_usage') diff --git a/databuilder/tests/unit/transformer/test_chained_transformer.py b/databuilder/tests/unit/transformer/test_chained_transformer.py new file mode 100644 index 0000000000..e920ccfe29 --- /dev/null +++ b/databuilder/tests/unit/transformer/test_chained_transformer.py @@ -0,0 +1,73 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from mock import MagicMock +from pyhocon import ConfigFactory + +from databuilder.transformer.base_transformer import ChainedTransformer + + +class TestChainedTransformer(unittest.TestCase): + def test_init_not_called(self) -> None: + + mock_transformer1 = MagicMock() + mock_transformer1.transform.return_value = "foo" + mock_transformer2 = MagicMock() + mock_transformer2.transform.return_value = "bar" + + chained_transformer = ChainedTransformer( + transformers=[mock_transformer1, mock_transformer2] + ) + + config = ConfigFactory.from_dict({}) + chained_transformer.init(conf=config) + + next(chained_transformer.transform({"foo": "bar"})) + + mock_transformer1.init.assert_not_called() + mock_transformer1.transform.assert_called_once() + mock_transformer2.init.assert_not_called() + mock_transformer2.transform.assert_called_once() + + def test_init_called(self) -> None: + + mock_transformer1 = MagicMock() + mock_transformer1.get_scope.return_value = "foo" + mock_transformer1.transform.return_value = "foo" + mock_transformer2 = MagicMock() + mock_transformer2.get_scope.return_value = "bar" + mock_transformer2.transform.return_value = "bar" + + chained_transformer = ChainedTransformer( + transformers=[mock_transformer1, mock_transformer2], + is_init_transformers=True, + ) + + config = ConfigFactory.from_dict({}) + chained_transformer.init(conf=config) + + next(chained_transformer.transform({"foo": "bar"})) + + mock_transformer1.init.assert_called_once() + mock_transformer1.transform.assert_called_once() + mock_transformer2.init.assert_called_once() + mock_transformer2.transform.assert_called_once() + + def test_transformer_transforms(self) -> None: + + mock_transformer1 = MagicMock() + mock_transformer1.transform.side_effect = lambda s: s + "b" + mock_transformer2 = MagicMock() + mock_transformer2.transform.side_effect = lambda s: s + "c" + + chained_transformer = ChainedTransformer( + transformers=[mock_transformer1, mock_transformer2] + ) + + config = ConfigFactory.from_dict({}) + chained_transformer.init(conf=config) + + result = next(chained_transformer.transform("a")) + self.assertEqual(result, "abc") diff --git a/databuilder/tests/unit/transformer/test_complex_type_transformer.py b/databuilder/tests/unit/transformer/test_complex_type_transformer.py new file mode 100644 index 0000000000..391eba2f5c --- /dev/null +++ b/databuilder/tests/unit/transformer/test_complex_type_transformer.py @@ -0,0 +1,140 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest +from unittest.mock import MagicMock, patch + +from pyhocon import ConfigFactory + +from databuilder.models.table_metadata import ColumnMetadata, TableMetadata +from databuilder.models.type_metadata import ( + ArrayTypeMetadata, ScalarTypeMetadata, TypeMetadata, +) +from databuilder.transformer.complex_type_transformer import PARSING_FUNCTION, ComplexTypeTransformer + + +class TestComplexTypeTransformer(unittest.TestCase): + def test_invalid_parsing_function_missing_module(self) -> None: + transformer = ComplexTypeTransformer() + config = ConfigFactory.from_dict({ + PARSING_FUNCTION: 'invalid_function', + }) + with self.assertRaises(Exception): + transformer.init(conf=config) + + def test_invalid_parsing_function_invalid_module(self) -> None: + transformer = ComplexTypeTransformer() + config = ConfigFactory.from_dict({ + PARSING_FUNCTION: 'invalid_module.invalid_function', + }) + with self.assertRaises(Exception): + transformer.init(conf=config) + + def test_invalid_parsing_function_invalid_function(self) -> None: + transformer = ComplexTypeTransformer() + config = ConfigFactory.from_dict({ + PARSING_FUNCTION: 'databuilder.utils.hive_complex_type_parser.invalid_function', + }) + with self.assertRaises(Exception): + transformer.init(conf=config) + + def test_hive_parser_with_failures(self) -> None: + transformer = ComplexTypeTransformer() + config = ConfigFactory.from_dict({ + PARSING_FUNCTION: 'databuilder.utils.hive_complex_type_parser.parse_hive_type', + }) + transformer.init(conf=config) + + column = ColumnMetadata('col1', 'array type', 'array>', 0) + table_metadata = TableMetadata( + 'hive', + 'gold', + 'test_schema', + 'test_table', + 'test_table', + [column] + ) + + default_scalar_type = ScalarTypeMetadata(name='col1', + parent=column, + type_str='array>') + + with patch.object(transformer, '_parsing_function') as mock: + mock.side_effect = MagicMock(side_effect=Exception('Could not parse')) + + result = transformer.transform(table_metadata) + + self.assertEqual(transformer.success_count, 0) + self.assertEqual(transformer.failure_count, 1) + for actual in result.columns: + self.assertEqual(actual.get_type_metadata(), default_scalar_type) + + def test_hive_parser_usage(self) -> None: + transformer = ComplexTypeTransformer() + config = ConfigFactory.from_dict({ + PARSING_FUNCTION: 'databuilder.utils.hive_complex_type_parser.parse_hive_type', + }) + transformer.init(conf=config) + + column = ColumnMetadata('col1', 'array type', 'array>', 0) + table_metadata = TableMetadata( + 'hive', + 'gold', + 'test_schema', + 'test_table', + 'test_table', + [column] + ) + array_type = ArrayTypeMetadata(name='col1', + parent=column, + type_str='array>') + inner_array = ArrayTypeMetadata(name='_inner_', + parent=array_type, + type_str='array') + + array_type.array_inner_type = inner_array + + result = transformer.transform(table_metadata) + + for actual in result.columns: + self.assertTrue(isinstance(actual.get_type_metadata(), TypeMetadata)) + self.assertEqual(actual.get_type_metadata(), array_type) + self.assertEqual(transformer.success_count, 1) + self.assertEqual(transformer.failure_count, 0) + + def test_trino_parser_usage(self) -> None: + transformer = ComplexTypeTransformer() + config = ConfigFactory.from_dict({ + PARSING_FUNCTION: 'databuilder.utils.trino_complex_type_parser.parse_trino_type', + }) + transformer.init(conf=config) + + column = ColumnMetadata('col1', 'array type', 'array(array(int))', 0) + table_metadata = TableMetadata( + 
'trino', + 'gold', + 'test_schema', + 'test_table', + 'test_table', + [column] + ) + array_type = ArrayTypeMetadata(name='col1', + parent=column, + type_str='array(array(int))') + inner_array = ArrayTypeMetadata(name='_inner_', + parent=array_type, + type_str='array(int)') + + array_type.array_inner_type = inner_array + + result = transformer.transform(table_metadata) + + for actual in result.columns: + self.assertTrue(isinstance(actual.get_type_metadata(), TypeMetadata)) + self.assertEqual(actual.get_type_metadata(), array_type) + self.assertEqual(transformer.success_count, 1) + self.assertEqual(transformer.failure_count, 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/transformer/test_dict_to_model_transformer.py b/databuilder/tests/unit/transformer/test_dict_to_model_transformer.py new file mode 100644 index 0000000000..53857c3436 --- /dev/null +++ b/databuilder/tests/unit/transformer/test_dict_to_model_transformer.py @@ -0,0 +1,41 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from pyhocon import ConfigFactory + +from databuilder.models.dashboard.dashboard_execution import DashboardExecution +from databuilder.transformer.dict_to_model import MODEL_CLASS, DictToModel + + +class TestDictToModel(unittest.TestCase): + + def test_conversion(self) -> None: + + transformer = DictToModel() + config = ConfigFactory.from_dict({ + MODEL_CLASS: 'databuilder.models.dashboard.dashboard_execution.DashboardExecution', + }) + transformer.init(conf=config) + + actual = transformer.transform( + { + 'dashboard_group_id': 'foo', + 'dashboard_id': 'bar', + 'execution_timestamp': 123456789, + 'execution_state': 'succeed', + 'product': 'mode', + 'cluster': 'gold' + } + ) + + self.assertTrue(isinstance(actual, DashboardExecution)) + self.assertEqual(actual.__repr__(), DashboardExecution( + dashboard_group_id='foo', + dashboard_id='bar', + execution_timestamp=123456789, + execution_state='succeed', + product='mode', + cluster='gold' + ).__repr__()) diff --git a/databuilder/tests/unit/transformer/test_regex_str_replace_transformer.py b/databuilder/tests/unit/transformer/test_regex_str_replace_transformer.py new file mode 100644 index 0000000000..320c20361a --- /dev/null +++ b/databuilder/tests/unit/transformer/test_regex_str_replace_transformer.py @@ -0,0 +1,73 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import Any + +from pyhocon import ConfigFactory + +from databuilder.transformer.regex_str_replace_transformer import ( + ATTRIBUTE_NAME, REGEX_REPLACE_TUPLE_LIST, RegexStrReplaceTransformer, +) + + +class TestRegexReplacement(unittest.TestCase): + + def test(self) -> None: + transformer = self._default_test_transformer() + + foo = Foo('abc') + actual = transformer.transform(foo) + + self.assertEqual('bba', actual.val) + + def test_numeric_val(self) -> None: + transformer = self._default_test_transformer() + + foo = Foo(6) + actual = transformer.transform(foo) + + self.assertEqual(6, actual.val) + + def test_none_val(self) -> None: + transformer = self._default_test_transformer() + + foo = Foo(None) + actual = transformer.transform(foo) + + self.assertEqual(None, actual.val) + + def _default_test_transformer(self) -> RegexStrReplaceTransformer: + config = ConfigFactory.from_dict({ + REGEX_REPLACE_TUPLE_LIST: [('a', 'b'), ('c', 'a')], + ATTRIBUTE_NAME: 'val' + }) + + transformer = RegexStrReplaceTransformer() + transformer.init(config) + + return transformer + + def test_dict_replace(self) -> None: + config = ConfigFactory.from_dict({ + REGEX_REPLACE_TUPLE_LIST: [('\\', '\\\\')], + ATTRIBUTE_NAME: 'val' + }) + + transformer = RegexStrReplaceTransformer() + transformer.init(config) + + d = {'val': '\\'} + + actual = transformer.transform(d) + + self.assertEqual({'val': '\\\\'}, actual) + + +class Foo(object): + def __init__(self, val: Any) -> None: + self.val = val + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/transformer/test_remove_field_transformer.py b/databuilder/tests/unit/transformer/test_remove_field_transformer.py new file mode 100644 index 0000000000..0ce28e5ff0 --- /dev/null +++ b/databuilder/tests/unit/transformer/test_remove_field_transformer.py @@ -0,0 +1,52 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from pyhocon import ConfigFactory + +from databuilder.transformer.remove_field_transformer import FIELD_NAMES, RemoveFieldTransformer + + +class TestRemoveFieldTransformer(unittest.TestCase): + + def test_conversion(self) -> None: + + transformer = RemoveFieldTransformer() + config = ConfigFactory.from_dict({ + FIELD_NAMES: ['foo', 'bar'], + }) + transformer.init(conf=config) + + actual = transformer.transform({ + 'foo': 'foo_val', + 'bar': 'bar_val', + 'baz': 'baz_val', + }) + expected = { + 'baz': 'baz_val' + } + self.assertDictEqual(expected, actual) + + def test_conversion_missing_field(self) -> None: + + transformer = RemoveFieldTransformer() + config = ConfigFactory.from_dict({ + FIELD_NAMES: ['foo', 'bar'], + }) + transformer.init(conf=config) + + actual = transformer.transform({ + 'foo': 'foo_val', + 'baz': 'baz_val', + 'john': 'doe', + }) + expected = { + 'baz': 'baz_val', + 'john': 'doe' + } + self.assertDictEqual(expected, actual) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/transformer/test_table_tag_transformer.py b/databuilder/tests/unit/transformer/test_table_tag_transformer.py new file mode 100644 index 0000000000..4f8efda0e5 --- /dev/null +++ b/databuilder/tests/unit/transformer/test_table_tag_transformer.py @@ -0,0 +1,80 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from pyhocon import ConfigFactory + +from databuilder.models.table_metadata import TableMetadata +from databuilder.transformer.table_tag_transformer import TableTagTransformer + + +class TestTableTagTransformer(unittest.TestCase): + def test_single_tag(self) -> None: + transformer = TableTagTransformer() + config = ConfigFactory.from_dict({ + TableTagTransformer.TAGS: 'foo', + }) + transformer.init(conf=config) + + result = transformer.transform(TableMetadata( + database='test_db', + cluster='test_cluster', + schema='test_schema', + name='test_table', + description='', + )) + + self.assertEqual(result.tags, ['foo']) + + def test_multiple_tags_comma_delimited(self) -> None: + transformer = TableTagTransformer() + config = ConfigFactory.from_dict({ + TableTagTransformer.TAGS: 'foo,bar', + }) + transformer.init(conf=config) + + result = transformer.transform(TableMetadata( + database='test_db', + cluster='test_cluster', + schema='test_schema', + name='test_table', + description='', + )) + + self.assertEqual(result.tags, ['foo', 'bar']) + + def test_add_tag_to_existing_tags(self) -> None: + transformer = TableTagTransformer() + config = ConfigFactory.from_dict({ + TableTagTransformer.TAGS: 'baz', + }) + transformer.init(conf=config) + + result = transformer.transform(TableMetadata( + database='test_db', + cluster='test_cluster', + schema='test_schema', + name='test_table', + description='', + tags='foo,bar', + )) + self.assertEqual(result.tags, ['foo', 'bar', 'baz']) + + def test_tags_not_added_to_other_objects(self) -> None: + transformer = TableTagTransformer() + config = ConfigFactory.from_dict({ + TableTagTransformer.TAGS: 'new_tag', + }) + transformer.init(conf=config) + + class NotATable(): + tags = 'existing_tag' + + result = transformer.transform(NotATable()) + + self.assertEqual(result.tags, 'existing_tag') + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/transformer/test_template_variable_substitution_transformer.py b/databuilder/tests/unit/transformer/test_template_variable_substitution_transformer.py new file mode 100644 index 0000000000..0e10428fdc --- /dev/null +++ b/databuilder/tests/unit/transformer/test_template_variable_substitution_transformer.py @@ -0,0 +1,33 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from pyhocon import ConfigFactory + +from databuilder.transformer.template_variable_substitution_transformer import ( + FIELD_NAME, TEMPLATE, TemplateVariableSubstitutionTransformer, +) + + +class TestTemplateVariableSubstitutionTransformer(unittest.TestCase): + + def test_conversion(self) -> None: + + transformer = TemplateVariableSubstitutionTransformer() + config = ConfigFactory.from_dict({ + FIELD_NAME: 'baz', + TEMPLATE: 'Hello {foo}' + }) + transformer.init(conf=config) + + actual = transformer.transform({'foo': 'bar'}) + expected = { + 'foo': 'bar', + 'baz': 'Hello bar' + } + self.assertDictEqual(expected, actual) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/transformer/test_timestamp_string_to_epoch_transformer.py b/databuilder/tests/unit/transformer/test_timestamp_string_to_epoch_transformer.py new file mode 100644 index 0000000000..c87fe44e1f --- /dev/null +++ b/databuilder/tests/unit/transformer/test_timestamp_string_to_epoch_transformer.py @@ -0,0 +1,49 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from pyhocon import ConfigFactory + +from databuilder.transformer.timestamp_string_to_epoch import ( + FIELD_NAME, TIMESTAMP_FORMAT, TimestampStringToEpoch, +) + + +class TestTimestampStrToEpoch(unittest.TestCase): + + def test_conversion(self) -> None: + + transformer = TimestampStringToEpoch() + config = ConfigFactory.from_dict({ + FIELD_NAME: 'foo', + }) + transformer.init(conf=config) + + actual = transformer.transform({'foo': '2020-02-19T19:52:33.1Z'}) + self.assertDictEqual({'foo': 1582141953}, actual) + + def test_conversion_with_format(self) -> None: + + transformer = TimestampStringToEpoch() + config = ConfigFactory.from_dict({ + FIELD_NAME: 'foo', + TIMESTAMP_FORMAT: '%Y-%m-%dT%H:%M:%SZ' + }) + transformer.init(conf=config) + + actual = transformer.transform({'foo': '2020-02-19T19:52:33Z'}) + self.assertDictEqual({'foo': 1582141953}, actual) + + def test_invalid_timestamp(self) -> None: + transformer = TimestampStringToEpoch() + config = ConfigFactory.from_dict({ + FIELD_NAME: 'foo', + }) + transformer.init(conf=config) + actual = transformer.transform({'foo': '165de33266d4'}) + self.assertEqual(actual['foo'], 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/usage/__init__.py b/databuilder/tests/unit/usage/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/usage/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/usage/presto/__init__.py b/databuilder/tests/unit/usage/presto/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/usage/presto/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/utils/__init__.py b/databuilder/tests/unit/utils/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/databuilder/tests/unit/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/tests/unit/utils/test_hive_complex_type_parser.py b/databuilder/tests/unit/utils/test_hive_complex_type_parser.py new file mode 100644 index 0000000000..0b3a45ef13 --- /dev/null +++ b/databuilder/tests/unit/utils/test_hive_complex_type_parser.py @@ -0,0 +1,300 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from pyparsing import ParseException + +from databuilder.models.table_metadata import ColumnMetadata +from databuilder.models.type_metadata import ( + ArrayTypeMetadata, MapTypeMetadata, ScalarTypeMetadata, StructTypeMetadata, +) +from databuilder.utils.hive_complex_type_parser import parse_hive_type + + +class TestHiveComplexTypeParser(unittest.TestCase): + def setUp(self) -> None: + self.column_key = 'hive://gold.test_schema/test_table/col1' + + def test_transform_no_complex_type(self) -> None: + column = ColumnMetadata('col1', None, 'int', 0) + column.set_column_key(self.column_key) + + scalar_type = ScalarTypeMetadata(name='col1', + parent=column, + type_str='int') + + actual = parse_hive_type(column.type, column.name, column) + self.assertEqual(actual, scalar_type) + + def test_transform_array_type(self) -> None: + column = ColumnMetadata('col1', None, 'array>', 0) + column.set_column_key(self.column_key) + + array_type = ArrayTypeMetadata(name='col1', + parent=column, + type_str='array>') + inner_array = ArrayTypeMetadata(name='_inner_', + parent=array_type, + type_str='array') + + array_type.array_inner_type = inner_array + + actual = parse_hive_type(column.type, column.name, column) + self.assertEqual(actual, array_type) + + def test_transform_array_map_nested_type(self) -> None: + column = ColumnMetadata('col1', None, 'array>', 0) + column.set_column_key(self.column_key) + + array_type = ArrayTypeMetadata(name='col1', + parent=column, + type_str='array>') + inner_map = MapTypeMetadata(name='_inner_', + parent=array_type, + type_str='map') + inner_map_key = ScalarTypeMetadata(name='_map_key', + parent=inner_map, + type_str='string') + inner_scalar = ScalarTypeMetadata(name='_map_value', + parent=inner_map, + type_str='int') + + array_type.array_inner_type = inner_map + inner_map.map_key_type = inner_map_key + inner_map.map_value_type = inner_scalar + + actual = parse_hive_type(column.type, column.name, column) + self.assertEqual(actual, array_type) + + def test_transform_array_struct_nested_type(self) -> None: + column = ColumnMetadata('col1', None, 'array>', 0) + column.set_column_key(self.column_key) + + array_type = ArrayTypeMetadata(name='col1', + parent=column, + type_str='array>') + inner_struct = StructTypeMetadata(name='_inner_', + parent=array_type, + type_str='struct') + inner_scalar_nest1 = ScalarTypeMetadata(name='nest1', + parent=inner_struct, + type_str='int') + inner_scalar_nest2 = ScalarTypeMetadata(name='nest2', + parent=inner_struct, + type_str='int') + + array_type.array_inner_type = inner_struct + inner_struct.struct_items = {'nest1': inner_scalar_nest1, 'nest2': inner_scalar_nest2} + inner_scalar_nest1.sort_order = 0 + inner_scalar_nest2.sort_order = 1 + + actual = parse_hive_type(column.type, column.name, column) + self.assertEqual(actual, array_type) + + def test_transform_map_type(self) -> None: + column = ColumnMetadata('col1', None, 'map>', 0) + column.set_column_key(self.column_key) + + map_type = MapTypeMetadata(name='col1', + parent=column, + type_str='map>') + map_key = ScalarTypeMetadata(name='_map_key', + parent=map_type, + type_str='string') + map_value = MapTypeMetadata(name='_map_value', + parent=map_type, + type_str='map') + inner_map_key = ScalarTypeMetadata(name='_map_key', + parent=map_value, + type_str='string') + inner_scalar = ScalarTypeMetadata(name='_map_value', + parent=map_value, + type_str='int') + + map_type.map_key_type = map_key + map_type.map_value_type = map_value + 
map_value.map_key_type = inner_map_key + map_value.map_value_type = inner_scalar + + actual = parse_hive_type(column.type, column.name, column) + self.assertEqual(actual, map_type) + + def test_transform_map_struct_nested_type(self) -> None: + column = ColumnMetadata('col1', None, 'map>', 0) + column.set_column_key(self.column_key) + + map_type = MapTypeMetadata(name='col1', + parent=column, + type_str='map>') + map_key = ScalarTypeMetadata(name='_map_key', + parent=map_type, + type_str='string') + inner_struct = StructTypeMetadata(name='_map_value', + parent=map_type, + type_str='struct') + inner_scalar_nest1 = ScalarTypeMetadata(name='nest1', + parent=inner_struct, + type_str='int') + inner_scalar_nest2 = ScalarTypeMetadata(name='nest2', + parent=inner_struct, + type_str='int') + + map_type.map_key_type = map_key + map_type.map_value_type = inner_struct + inner_struct.struct_items = {'nest1': inner_scalar_nest1, 'nest2': inner_scalar_nest2} + inner_scalar_nest1.sort_order = 0 + inner_scalar_nest2.sort_order = 1 + + actual = parse_hive_type(column.type, column.name, column) + self.assertEqual(actual, map_type) + + def test_transform_struct_type(self) -> None: + column = ColumnMetadata('col1', None, 'struct', 0) + column.set_column_key(self.column_key) + + struct_type = StructTypeMetadata(name='col1', + parent=column, + type_str='struct') + inner_scalar_nest1 = ScalarTypeMetadata(name='nest1', + parent=struct_type, + type_str='int') + inner_scalar_nest2 = ScalarTypeMetadata(name='nest2', + parent=struct_type, + type_str='int') + + struct_type.struct_items = {'nest1': inner_scalar_nest1, 'nest2': inner_scalar_nest2} + inner_scalar_nest1.sort_order = 0 + inner_scalar_nest2.sort_order = 1 + + actual = parse_hive_type(column.type, column.name, column) + self.assertEqual(actual, struct_type) + + def test_transform_struct_map_array_nested_type(self) -> None: + column = ColumnMetadata('col1', None, 'struct>,nest2:array>', 0) + column.set_column_key(self.column_key) + + struct_type = StructTypeMetadata(name='col1', + parent=column, + type_str='struct>,nest2:array>') + inner_map = MapTypeMetadata(name='nest1', + parent=struct_type, + type_str='map>') + inner_map_key = ScalarTypeMetadata(name='_map_key', + parent=inner_map, + type_str='string') + inner_map_array = ArrayTypeMetadata(name='_map_value', + parent=inner_map, + type_str='array') + inner_struct_array = ArrayTypeMetadata(name='nest2', + parent=struct_type, + type_str='array') + + struct_type.struct_items = {'nest1': inner_map, 'nest2': inner_struct_array} + inner_map.map_key_type = inner_map_key + inner_map.map_value_type = inner_map_array + inner_map.sort_order = 0 + inner_struct_array.sort_order = 1 + + actual = parse_hive_type(column.type, column.name, column) + self.assertEqual(actual, struct_type) + + def test_transform_non_alpha_only_types(self) -> None: + column = ColumnMetadata('col1', None, 'struct,' + 'nest5:interval_day_time>', 0) + column.set_column_key(self.column_key) + + struct_type = StructTypeMetadata(name='col1', + parent=column, + type_str='struct,' + 'nest5:interval_day_time>') + inner_scalar_nest1 = ScalarTypeMetadata(name='nest1', + parent=struct_type, + type_str='decimal(10,2)') + inner_scalar_nest2 = ScalarTypeMetadata(name='nest2', + parent=struct_type, + type_str='double precision') + inner_scalar_nest3 = ScalarTypeMetadata(name='nest3', + parent=struct_type, + type_str='varchar(32)') + inner_map_nest4 = MapTypeMetadata(name='nest4', + parent=struct_type, + type_str='map') + inner_map_nest4_key = 
ScalarTypeMetadata(name='_map_key', + parent=inner_map_nest4, + type_str='varchar(32)') + inner_map_nest4_value = ScalarTypeMetadata(name='_map_value', + parent=inner_map_nest4, + type_str='decimal(10,2)') + inner_scalar_nest5 = ScalarTypeMetadata(name='nest5', + parent=struct_type, + type_str='interval_day_time') + + struct_type.struct_items = {'nest1': inner_scalar_nest1, 'nest2': inner_scalar_nest2, + 'nest3': inner_scalar_nest3, 'nest4': inner_map_nest4, + 'nest5': inner_scalar_nest5} + inner_map_nest4.map_key_type = inner_map_nest4_key + inner_map_nest4.map_value_type = inner_map_nest4_value + inner_scalar_nest1.sort_order = 0 + inner_scalar_nest2.sort_order = 1 + inner_scalar_nest3.sort_order = 2 + inner_map_nest4.sort_order = 3 + inner_scalar_nest5.sort_order = 4 + + actual = parse_hive_type(column.type, column.name, column) + self.assertEqual(actual, struct_type) + + def test_transform_union_as_scalar_type(self) -> None: + column = ColumnMetadata('col1', None, 'uniontype>', 0) + column.set_column_key(self.column_key) + + struct_type = ScalarTypeMetadata(name='col1', + parent=column, + type_str='uniontype>') + + actual = parse_hive_type(column.type, column.name, column) + self.assertEqual(actual, struct_type) + + def test_transform_union_as_nested_type(self) -> None: + column = ColumnMetadata('col1', None, 'struct>,' + 'nest2:uniontype>', 0) + column.set_column_key(self.column_key) + + struct_type = StructTypeMetadata(name='col1', + parent=column, + type_str='struct>,' + 'nest2:uniontype>') + inner_scalar_nest1 = ScalarTypeMetadata(name='nest1', + parent=struct_type, + type_str='uniontype>') + inner_scalar_nest2 = ScalarTypeMetadata(name='nest2', + parent=struct_type, + type_str='uniontype') + + struct_type.struct_items = {'nest1': inner_scalar_nest1, 'nest2': inner_scalar_nest2} + inner_scalar_nest1.sort_order = 0 + inner_scalar_nest2.sort_order = 1 + + actual = parse_hive_type(column.type, column.name, column) + self.assertEqual(actual, struct_type) + + def test_transform_invalid_array_inner_type(self) -> None: + column = ColumnMetadata('col1', None, 'array>', 0) + column.set_column_key(self.column_key) + + with self.assertRaises(ParseException): + parse_hive_type(column.type, column.name, column) + + def test_transform_invalid_struct_inner_type(self) -> None: + column = ColumnMetadata('col1', None, 'struct>', 0) + column.set_column_key(self.column_key) + + with self.assertRaises(ParseException): + parse_hive_type(column.type, column.name, column) + + +if __name__ == '__main__': + unittest.main() diff --git a/databuilder/tests/unit/utils/test_trino_complex_type_parser.py b/databuilder/tests/unit/utils/test_trino_complex_type_parser.py new file mode 100644 index 0000000000..9877966dea --- /dev/null +++ b/databuilder/tests/unit/utils/test_trino_complex_type_parser.py @@ -0,0 +1,302 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from pyparsing import ParseException + +from databuilder.models.table_metadata import ColumnMetadata +from databuilder.models.type_metadata import ( + ArrayTypeMetadata, MapTypeMetadata, ScalarTypeMetadata, StructTypeMetadata, +) +from databuilder.utils.trino_complex_type_parser import parse_trino_type + + +class TestTrinoComplexTypeParser(unittest.TestCase): + def setUp(self) -> None: + self.column_key = 'trino://gold.test_schema/test_table/col1' + + def test_transform_no_complex_type(self) -> None: + column = ColumnMetadata('col1', None, 'int', 0) + column.set_column_key(self.column_key) + + scalar_type = ScalarTypeMetadata(name='col1', + parent=column, + type_str='int') + + actual = parse_trino_type(column.type, column.name, column) + self.assertEqual(actual, scalar_type) + + def test_transform_array_type(self) -> None: + column = ColumnMetadata('col1', None, 'array(array(int))', 0) + column.set_column_key(self.column_key) + + array_type = ArrayTypeMetadata(name='col1', + parent=column, + type_str='array(array(int))') + inner_array = ArrayTypeMetadata(name='_inner_', + parent=array_type, + type_str='array(int)') + + array_type.array_inner_type = inner_array + + actual = parse_trino_type(column.type, column.name, column) + self.assertEqual(actual, array_type) + + def test_transform_array_map_nested_type(self) -> None: + column = ColumnMetadata('col1', None, 'array(map(string,int))', 0) + column.set_column_key(self.column_key) + + array_type = ArrayTypeMetadata(name='col1', + parent=column, + type_str='array(map(string,int))') + inner_map = MapTypeMetadata(name='_inner_', + parent=array_type, + type_str='map(string,int)') + inner_map_key = ScalarTypeMetadata(name='_map_key', + parent=inner_map, + type_str='string') + inner_scalar = ScalarTypeMetadata(name='_map_value', + parent=inner_map, + type_str='int') + + array_type.array_inner_type = inner_map + inner_map.map_key_type = inner_map_key + inner_map.map_value_type = inner_scalar + + actual = parse_trino_type(column.type, column.name, column) + self.assertEqual(actual, array_type) + + def test_transform_array_struct_nested_type(self) -> None: + column = ColumnMetadata('col1', None, 'array(row(nest1 int,nest2 int))', 0) + column.set_column_key(self.column_key) + + array_type = ArrayTypeMetadata(name='col1', + parent=column, + type_str='array(row(nest1 int,nest2 int))') + inner_struct = StructTypeMetadata(name='_inner_', + parent=array_type, + type_str='row(nest1 int,nest2 int)') + inner_scalar_nest1 = ScalarTypeMetadata(name='nest1', + parent=inner_struct, + type_str='int') + inner_scalar_nest2 = ScalarTypeMetadata(name='nest2', + parent=inner_struct, + type_str='int') + + array_type.array_inner_type = inner_struct + inner_struct.struct_items = {'nest1': inner_scalar_nest1, 'nest2': inner_scalar_nest2} + inner_scalar_nest1.sort_order = 0 + inner_scalar_nest2.sort_order = 1 + + actual = parse_trino_type(column.type, column.name, column) + self.assertEqual(actual, array_type) + + def test_transform_map_type(self) -> None: + column = ColumnMetadata('col1', None, 'map(string,map(string,int))', 0) + column.set_column_key(self.column_key) + + map_type = MapTypeMetadata(name='col1', + parent=column, + type_str='map(string,map(string,int))') + map_key = ScalarTypeMetadata(name='_map_key', + parent=map_type, + type_str='string') + map_value = MapTypeMetadata(name='_map_value', + parent=map_type, + type_str='map(string,int)') + inner_map_key = ScalarTypeMetadata(name='_map_key', + 
parent=map_value, + type_str='string') + inner_scalar = ScalarTypeMetadata(name='_map_value', + parent=map_value, + type_str='int') + + map_type.map_key_type = map_key + map_type.map_value_type = map_value + map_value.map_key_type = inner_map_key + map_value.map_value_type = inner_scalar + + actual = parse_trino_type(column.type, column.name, column) + self.assertEqual(actual, map_type) + + def test_transform_map_struct_nested_type(self) -> None: + column = ColumnMetadata('col1', None, 'map(string,row(nest1 int,nest2 int))', 0) + column.set_column_key(self.column_key) + + map_type = MapTypeMetadata(name='col1', + parent=column, + type_str='map(string,row(nest1 int,nest2 int))') + map_key = ScalarTypeMetadata(name='_map_key', + parent=map_type, + type_str='string') + inner_struct = StructTypeMetadata(name='_map_value', + parent=map_type, + type_str='row(nest1 int,nest2 int)') + inner_scalar_nest1 = ScalarTypeMetadata(name='nest1', + parent=inner_struct, + type_str='int') + inner_scalar_nest2 = ScalarTypeMetadata(name='nest2', + parent=inner_struct, + type_str='int') + + map_type.map_key_type = map_key + map_type.map_value_type = inner_struct + inner_struct.struct_items = {'nest1': inner_scalar_nest1, 'nest2': inner_scalar_nest2} + inner_scalar_nest1.sort_order = 0 + inner_scalar_nest2.sort_order = 1 + + actual = parse_trino_type(column.type, column.name, column) + self.assertEqual(actual, map_type) + + def test_transform_struct_type(self) -> None: + column = ColumnMetadata('col1', None, 'row(nest1 int,nest2 int)', 0) + column.set_column_key(self.column_key) + + struct_type = StructTypeMetadata(name='col1', + parent=column, + type_str='row(nest1 int,nest2 int)') + inner_scalar_nest1 = ScalarTypeMetadata(name='nest1', + parent=struct_type, + type_str='int') + inner_scalar_nest2 = ScalarTypeMetadata(name='nest2', + parent=struct_type, + type_str='int') + + struct_type.struct_items = {'nest1': inner_scalar_nest1, 'nest2': inner_scalar_nest2} + inner_scalar_nest1.sort_order = 0 + inner_scalar_nest2.sort_order = 1 + + actual = parse_trino_type(column.type, column.name, column) + self.assertEqual(actual, struct_type) + + def test_transform_struct_map_array_nested_type(self) -> None: + column = ColumnMetadata('col1', None, 'row(nest1 map(string,array(int)),nest2 array(string))', 0) + column.set_column_key(self.column_key) + + struct_type = StructTypeMetadata(name='col1', + parent=column, + type_str='row(nest1 map(string,array(int)),nest2 array(string))') + inner_map = MapTypeMetadata(name='nest1', + parent=struct_type, + type_str='map(string,array(int))') + inner_map_key = ScalarTypeMetadata(name='_map_key', + parent=inner_map, + type_str='string') + inner_map_array = ArrayTypeMetadata(name='_map_value', + parent=inner_map, + type_str='array(int)') + inner_struct_array = ArrayTypeMetadata(name='nest2', + parent=struct_type, + type_str='array(string)') + + struct_type.struct_items = {'nest1': inner_map, 'nest2': inner_struct_array} + inner_map.map_key_type = inner_map_key + inner_map.map_value_type = inner_map_array + inner_map.sort_order = 0 + inner_struct_array.sort_order = 1 + + actual = parse_trino_type(column.type, column.name, column) + self.assertEqual(actual, struct_type) + + def test_transform_struct_nested_type_with_quoted_names(self) -> None: + column = ColumnMetadata('col1', None, 'row("nest1" varchar,"nest2" row("nest3" varchar,' + '"nest4" timestamp(3),"nest5" timestamp(3)))', 0) + column.set_column_key(self.column_key) + + struct_type = StructTypeMetadata(name='col1', + parent=column, 
+ type_str='row(nest1 varchar,nest2 row(nest3 varchar,' + 'nest4 timestamp(3),nest5 timestamp(3)))') + inner_scalar = ScalarTypeMetadata(name='nest1', + parent=struct_type, + type_str='varchar') + inner_struct = StructTypeMetadata(name='nest2', + parent=struct_type, + type_str='row(nest3 varchar,nest4 timestamp(3),nest5 timestamp(3))') + inner_scalar_1 = ScalarTypeMetadata(name='nest3', + parent=inner_struct, + type_str='varchar') + inner_scalar_2 = ScalarTypeMetadata(name='nest4', + parent=inner_struct, + type_str='timestamp(3)') + inner_scalar_3 = ScalarTypeMetadata(name='nest5', + parent=inner_struct, + type_str='timestamp(3)') + + struct_type.struct_items = {'nest1': inner_scalar, 'nest2': inner_struct} + inner_struct.struct_items = {'nest3': inner_scalar_1, 'nest4': inner_scalar_2, 'nest5': inner_scalar_3} + inner_scalar.sort_order = 0 + inner_struct.sort_order = 1 + inner_scalar_1.sort_order = 0 + inner_scalar_2.sort_order = 1 + inner_scalar_3.sort_order = 2 + + actual = parse_trino_type(column.type, column.name, column) + self.assertEqual(actual, struct_type) + + def test_transform_non_alpha_only_types(self) -> None: + column = ColumnMetadata('col1', None, 'row(nest1 decimal(10,2),nest2 double precision,' + 'nest3 varchar(32),nest4 map(varchar(32),decimal(10,2)),' + 'nest5 interval_day_time)', 0) + column.set_column_key(self.column_key) + + struct_type = StructTypeMetadata(name='col1', + parent=column, + type_str='row(nest1 decimal(10,2),nest2 double precision,' + 'nest3 varchar(32),nest4 map(varchar(32),decimal(10,2)),' + 'nest5 interval_day_time)') + inner_scalar_nest1 = ScalarTypeMetadata(name='nest1', + parent=struct_type, + type_str='decimal(10,2)') + inner_scalar_nest2 = ScalarTypeMetadata(name='nest2', + parent=struct_type, + type_str='double precision') + inner_scalar_nest3 = ScalarTypeMetadata(name='nest3', + parent=struct_type, + type_str='varchar(32)') + inner_map_nest4 = MapTypeMetadata(name='nest4', + parent=struct_type, + type_str='map(varchar(32),decimal(10,2))') + inner_map_nest4_key = ScalarTypeMetadata(name='_map_key', + parent=inner_map_nest4, + type_str='varchar(32)') + inner_map_nest4_value = ScalarTypeMetadata(name='_map_value', + parent=inner_map_nest4, + type_str='decimal(10,2)') + inner_scalar_nest5 = ScalarTypeMetadata(name='nest5', + parent=struct_type, + type_str='interval_day_time') + + struct_type.struct_items = {'nest1': inner_scalar_nest1, 'nest2': inner_scalar_nest2, + 'nest3': inner_scalar_nest3, 'nest4': inner_map_nest4, + 'nest5': inner_scalar_nest5} + inner_map_nest4.map_key_type = inner_map_nest4_key + inner_map_nest4.map_value_type = inner_map_nest4_value + inner_scalar_nest1.sort_order = 0 + inner_scalar_nest2.sort_order = 1 + inner_scalar_nest3.sort_order = 2 + inner_map_nest4.sort_order = 3 + inner_scalar_nest5.sort_order = 4 + + actual = parse_trino_type(column.type, column.name, column) + self.assertEqual(actual, struct_type) + + def test_transform_invalid_array_inner_type(self) -> None: + column = ColumnMetadata('col1', None, 'array(array(int*))', 0) + column.set_column_key(self.column_key) + + with self.assertRaises(ParseException): + parse_trino_type(column.type, column.name, column) + + def test_transform_invalid_struct_inner_type(self) -> None: + column = ColumnMetadata('col1', None, 'row(nest1 varchar(256)å,' + 'nest2 (derived from deserializer))', 0) + column.set_column_key(self.column_key) + + with self.assertRaises(ParseException): + parse_trino_type(column.type, column.name, column) + + +if __name__ == '__main__': + 
unittest.main()

diff --git a/deployment-best-practices/index.html b/deployment-best-practices/index.html
new file mode 100644
index 0000000000..6dddd7c550
--- /dev/null
+++ b/deployment-best-practices/index.html
@@ -0,0 +1,1463 @@

Deployment best practices - Amundsen

Amundsen allows for many modifications, and many of them require code-level changes. Until we put together a “paved path” suggestion on how to manage such a set-up, for now we will document what community members are doing independently. If you have a production Amundsen deployment, please edit this doc to describe your setup.

Notes from community meeting 2020-12-03

These notes are from two companies at a community round-table: https://www.youtube.com/watch?v=gVf7S98bnyg

Brex

  • What modifications have you made?
    • We’ve added backups.
    • We wanted table descriptions to come solely from code.
  • How do you deploy/secure Amundsen?
    • Our hosting is behind a VPN, and we use OIDC.
  • What do you use Amundsen for?
    • Our primary use case: if I change this table, what dashboards will it break?
    • We also do PII tagging.
    • An ETL pipeline puts docs in Snowflake.

REA group

  • Why did you choose Amundsen?
    • We don’t have data stewards or formal roles who work on documentation. We liked that Amundsen didn’t rely on curated/manual documentation.
    • Google Data Catalog doesn’t allow you to search for data that you don’t have access to.
    • Things that we considered in other vendors: business metric glossary, column-level lineage.
  • How do you deploy Amundsen?
    • Deployment on ECS. We built the docker images on our own.
    • Deployment is done so that metadata is not lost. We looked into backing up metadata in AWS, but decided not to. Instead we use block storage, so even if the instance goes down, the metadata is still there.
    • We only index prod data sets.
    • We don’t consider Amundsen a source of truth. Thus, we don’t let people update descriptions.
    • An ETL indexer gets descriptions from BQ and puts them into Amundsen.
    • Postgres/source tables need some work to get descriptions from Go into Amundsen.
  • Some changes we’d like to make:
    • Authentication and Authorization
      • Workflow for requesting access to data you can’t already access: right now we don’t have a workflow for requesting access that’s connected to Amundsen. This seems like an area of investment.
    • Data Lineage
    • Business metrics glossary
  • Q&A
    • Why build your own images?
      • We want to make sure the system image and the code running on it are tightly controlled. We patch over the specific files on top of the Amundsen upstream code. We don’t fork right now; we chose to patch and not fork.
  • What was the process of getting alpha users onboard and getting feedback?
    • We chose ~8 people who had different roles and different tenure, then did UX interviews.
\ No newline at end of file
diff --git a/developer_guide/index.html b/developer_guide/index.html
new file mode 100644
index 0000000000..cc6fb47386
--- /dev/null
+++ b/developer_guide/index.html
@@ -0,0 +1,2019 @@

Overview - Amundsen

Developer Guide

+

This repository uses git submodules to link the code for all of Amundsen’s libraries into a central location. This document offers guidance on how to develop locally with this setup.

+

This workflow leverages docker and docker-compose in a very similar manner to our installation documentation, to spin up instances of all three of Amundsen’s services connected with instances of Neo4j and Elasticsearch, which ingest dummy data.

+

Cloning the Repository

+

If cloning the repository for the first time, run the following command to clone the repository and pull the submodules:

+
$ git clone --recursive git@github.com:amundsen-io/amundsen.git
+
+

If you have already cloned the repository but your submodules are empty, from your cloned amundsen directory run:

+
$ git submodule init
$ git submodule update
+
+

After cloning the repository you can change directories into any of the upstream folders and work in those directories as you normally would. You will have full access to all of the git features, and working in the upstream directories will function the same as if you were working in a cloned version of that repository.

+

Local Development

+

Ensure you have the latest code

+

Beyond running git pull origin master in your local amundsen directory, the submodules for our libraries also have to be manually updated to point to the latest version of each library’s code. When creating a new branch on amundsen to begin local work, ensure your local submodules are pointing to the latest code for each library by running:

+
$ git submodule update --remote
+
+

Building local changes

+
  1. First, be sure that you have followed the installation documentation and can spin up a default version of Amundsen without any issues. If you have already completed this step, be sure to have stopped and removed those containers by running:

     $ docker-compose -f docker-amundsen.yml down

  2. Launch the containers needed for local development (the -d option launches them in the background):

     $ docker-compose -f docker-amundsen-local.yml up -d

  3. After making local changes, rebuild and relaunch the modified containers:

     $ docker-compose -f docker-amundsen-local.yml build \
       && docker-compose -f docker-amundsen-local.yml up -d

  4. Optionally, to still tail logs, in a different terminal you can run:

     $ docker-compose -f docker-amundsen-local.yml logs --tail=3 -f
     ## - or just tail single container(s):
     $ docker logs amundsenmetadata --tail 10 -f

Local data

+

Local data is persisted under .local/ (at the root of the project); clean up the following directories to reset the databases:

+
#  reset elasticsearch
rm -rf .local/elasticsearch

#  reset neo4j
rm -rf .local/neo4j
+

Troubleshooting

+
  1. If you have made a change in amundsen/amundsenfrontendlibrary and do not see your changes, this could be due to your browser’s caching behaviors. Either execute a hard refresh (recommended) or clear your browser cache (last resort).

Testing Amundsen frontend locally

+

Amundsen has instructions for launching the frontend locally here

+

Here are some additional changes you might need for Windows (Windows 10):

  • In amundsen_application/config.py, set LOCAL_HOST = ‘127.0.0.1’
  • In amundsen_application/wsgi.py, set host=‘127.0.0.1’ (for the other microservices you also need to change the port here, because the default is 5000)

(Using that approach you can run the other microservices locally as well, if needed; see the sketch below.)

+
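As a rough illustration of the two Windows settings above, here is a minimal Flask stand-in, not the real amundsen_application code; the healthcheck route and port 5001 are made up for the example. It only shows binding the dev server explicitly to 127.0.0.1 on a non-default port:

# Minimal sketch only; the actual amundsen_application/wsgi.py differs.
from flask import Flask

app = Flask(__name__)


@app.route('/healthcheck')
def healthcheck() -> str:
    # Trivial endpoint so you can verify the running server in a browser.
    return 'ok'


if __name__ == '__main__':
    # Bind to 127.0.0.1 instead of 0.0.0.0; give each microservice its own
    # port if you run several locally (the Flask default is 5000).
    app.run(host='127.0.0.1', port=5001)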

Once you have a running frontend microservice, the rest of the Amundsen components can be launched with docker-compose from the root Amundsen project (don’t forget to remove the frontend microservice section from docker-amundsen.yml): docker-compose -f docker-amundsen.yml up (see https://github.com/amundsen-io/amundsen/blob/main/docs/installation.md)

+

Developing Dockerbuild file

+

When making edits to the Dockerbuild file (docker-amundsen-local.yml), it is good to see what you are getting wrong locally. To do that, build it with docker build .

+

And then the output should include a line like so at the step right before it failed:

+
Step 3/20 : RUN git clone --recursive git://github.com/amundsen-io/amundsenfrontendlibrary.git  && cd amundsenfrontendlibrary  && git submodule foreach git pull origin master
 ---> Using cache
 ---> ec052612747e
+

You can then launch a container from this image like so

+
docker container run -it --name=debug ec052612747e /bin/sh
+
+

Building and Testing Amundsen Frontend Docker Image (or any other service)

+
  1. Build your image with docker build --no-cache . (it is recommended that you use --no-cache so you aren’t accidentally using an old version of an image).
  2. Determine the hash of your image by running docker images and getting the id of your most recent image.
  3. Go to your locally cloned amundsen repo and edit the docker compose file “docker-amundsen.yml” to have the amundsenfrontend image point to the hash of the image that you built:

  amundsenfrontend:
      #image: amundsendev/amundsen-frontend:1.0.9
      #image: 1234.dkr.ecr.us-west-2.amazonaws.com/edmunds/amundsen-frontend:2020-01-21
      image: 0312d0ac3938
+

Pushing image to ECR and using in K8s

+

Assumptions:

+
    +
  • You have an AWS account.
  • You have the AWS command line set up and ready to go.
  • Choose an ECR repository you'd like to push to (or create a new one):
    https://us-west-2.console.aws.amazon.com/ecr/repositories
  • Click on the repository name and open the "View push commands" cheat sheet, then log in. The login command would look something like this:

    aws ecr get-login --no-include-email --region us-west-2

    Then execute what is returned by the command above.
  • Follow the instructions (you may first need to install the AWS CLI and aws-okta, and configure your AWS credentials if you haven't done so before).
    Given the image name amundsen-frontend, the build, tag and push commands will be the following.
    Here the tag is YYYY-MM-dd, but you should choose whatever you like.

    docker build -t amundsen-frontend:{YYYY-MM-dd} .
    docker tag amundsen-frontend:{YYYY-MM-dd} <?>.dkr.ecr.<?>.amazonaws.com/amundsen-frontend:{YYYY-MM-dd}
    docker push <?>.dkr.ecr.<?>.amazonaws.com/amundsen-frontend:{YYYY-MM-dd}

  • Go to helm/{env}/amundsen/values.yaml and change it to the image tag that you want to use.
  • When updating amundsen-frontend, make sure to do a hard refresh of Amundsen with the cache emptied, otherwise you will see a stale version of the webpage.
+

Test search service in local using staging or production data

+

To test locally, we need to stand up Elasticsearch, publish index data through Databuilder, and stand up the Search service

+

Standup Elasticsearch

+

Run Elasticsearch via Docker. To install Docker, go here. Example:

+
docker run -p 9200:9200  -p 9300:9300  -e "discovery.type=single-node" docker.elastic.co/elasticsearch/elasticsearch:6.2.4

(Optional) Standup Kibana

docker run --link ecstatic_edison:elasticsearch -p 5601:5601 docker.elastic.co/kibana/kibana:6.2.4

Note that ecstatic_edison is the name of the Elasticsearch container. Update it if yours is different by looking at docker ps.

+

Publish Table index through Databuilder

+
Install Databuilder
+
cd ~/src/
git clone git@github.com:amundsen-io/amundsendatabuilder.git
cd ~/src/amundsendatabuilder
virtualenv venv
source venv/bin/activate
python setup.py install
pip install -r requirements.txt
+
+ +
Publish Table index
+

First, set these two environment variables: NEO4J_ENDPOINT and CREDENTIALS_NEO4J_PASSWORD.

+
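If you prefer, you can also set them from inside the Python session right before running the snippet below. The values here are placeholders only, not real endpoints or credentials:

import os

# Placeholder values -- point these at the Neo4j instance you want to index.
os.environ['NEO4J_ENDPOINT'] = 'bolt://localhost:7687'
os.environ['CREDENTIALS_NEO4J_PASSWORD'] = 'your-neo4j-password'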
$ python
+
+import logging  
+import os  
+import uuid
+
+from elasticsearch import Elasticsearch  
+from pyhocon import ConfigFactory
+
+from databuilder.extractor.neo4j_extractor import Neo4jExtractor  
+from databuilder.extractor.neo4j_search_data_extractor import Neo4jSearchDataExtractor  
+from databuilder.job.job import DefaultJob  
+from databuilder.loader.file_system_elasticsearch_json_loader import FSElasticsearchJSONLoader  
+from databuilder.publisher.elasticsearch_publisher import ElasticsearchPublisher  
+from databuilder.task.task import DefaultTask
+
+logging.basicConfig(level=logging.INFO)
+
+neo4j_user = 'neo4j'  
+neo4j_password = os.getenv('CREDENTIALS_NEO4J_PASSWORD')  
+neo4j_endpoint = os.getenv('NEO4J_ENDPOINT')
+
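+# Client for the local Elasticsearch instance started above; the publisher
+# below writes to a brand-new index and then swaps the alias onto it.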
+elasticsearch_client = Elasticsearch([  
+    {'host': 'localhost'},  
+])
+
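+# Intermediate file where extracted search documents are staged before publishing.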
+data_file_path = '/var/tmp/amundsen/elasticsearch_upload/es_data.json'
+
+elasticsearch_new_index = 'table_search_index_{hex_str}'.format(hex_str=uuid.uuid4().hex)
+logging.info("Elasticsearch new index: " + elasticsearch_new_index)
+
+elasticsearch_doc_type = 'table'  
+elasticsearch_index_alias = 'table_search_index'
+
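+# Wire the Neo4j extractor, filesystem JSON loader and Elasticsearch publisher together.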
+job_config = ConfigFactory.from_dict({  
+    'extractor.search_data.extractor.neo4j.{}'.format(Neo4jExtractor.GRAPH_URL_CONFIG_KEY):  
+        neo4j_endpoint,  
+  'extractor.search_data.extractor.neo4j.{}'.format(Neo4jExtractor.MODEL_CLASS_CONFIG_KEY):  
+        'databuilder.models.table_elasticsearch_document.TableESDocument',  
+  'extractor.search_data.extractor.neo4j.{}'.format(Neo4jExtractor.NEO4J_AUTH_USER):  
+        neo4j_user,  
+  'extractor.search_data.extractor.neo4j.{}'.format(Neo4jExtractor.NEO4J_AUTH_PW):  
+        neo4j_password,  
+  'loader.filesystem.elasticsearch.{}'.format(FSElasticsearchJSONLoader.FILE_PATH_CONFIG_KEY):  
+        data_file_path,  
+  'loader.filesystem.elasticsearch.{}'.format(FSElasticsearchJSONLoader.FILE_MODE_CONFIG_KEY):  
+        'w',  
+  'publisher.elasticsearch.{}'.format(ElasticsearchPublisher.FILE_PATH_CONFIG_KEY):  
+        data_file_path,  
+  'publisher.elasticsearch.{}'.format(ElasticsearchPublisher.FILE_MODE_CONFIG_KEY):  
+        'r',  
+  'publisher.elasticsearch.{}'.format(ElasticsearchPublisher.ELASTICSEARCH_CLIENT_CONFIG_KEY):  
+        elasticsearch_client,  
+  'publisher.elasticsearch.{}'.format(ElasticsearchPublisher.ELASTICSEARCH_NEW_INDEX_CONFIG_KEY):  
+        elasticsearch_new_index,  
+  'publisher.elasticsearch.{}'.format(ElasticsearchPublisher.ELASTICSEARCH_DOC_TYPE_CONFIG_KEY):  
+        elasticsearch_doc_type,  
+  'publisher.elasticsearch.{}'.format(ElasticsearchPublisher.ELASTICSEARCH_ALIAS_CONFIG_KEY):  
+        elasticsearch_index_alias,  
+})
+
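+# Run the extract -> load -> publish pipeline (only when a Neo4j password is set).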
+job = DefaultJob(conf=job_config,  
+  task=DefaultTask(extractor=Neo4jSearchDataExtractor(),  
+  loader=FSElasticsearchJSONLoader()),  
+  publisher=ElasticsearchPublisher())  
+if neo4j_password:  
+    job.launch()  
+else:  
+    raise ValueError('Add environment variable CREDENTIALS_NEO4J_PASSWORD')
+
+ +
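After the job finishes, you can sanity-check that documents actually landed behind the alias. This is a minimal sketch, assuming the same local Elasticsearch instance and the table_search_index alias used above:

from elasticsearch import Elasticsearch

es = Elasticsearch([{'host': 'localhost'}])

# The publisher writes to a brand-new index and then swaps the alias to it,
# so querying the alias shows the freshly published documents.
print('documents in table_search_index:', es.count(index='table_search_index')['count'])

hits = es.search(index='table_search_index', size=1)['hits']['hits']
print(hits[0]['_source'] if hits else 'no documents found')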

Standup Search service

+

Follow these instructions

+

Test the search API with this command:

+
curl -vv "http://localhost:5001/search?query_term=test&page_index=0"
+
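If you prefer to exercise the same endpoint from Python (for example inside the virtualenv used above), here is a small sketch using requests. The port and query parameters mirror the curl call above; the response field names are only what the search service is expected to return, so adjust if yours differ:

import requests

resp = requests.get('http://localhost:5001/search',
                    params={'query_term': 'test', 'page_index': 0})
resp.raise_for_status()
body = resp.json()
# Print the match count and the first result, if any.
print('total results:', body.get('total_results'))
print((body.get('results') or [None])[0])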
\ No newline at end of file
diff --git a/faq/index.html b/faq/index.html new file mode 100644 index 0000000000..184f08adae --- /dev/null +++ b/faq/index.html @@ -0,0 +1,1616 @@

FAQ

+

How to select between Neo4j and Atlas as backend for Amundsen?

+

Why Neo4j?

+
    +
  1. Amundsen has direct influence over the data model if you use Neo4j. This, at least initially, will benefit the speed at which new features in Amundsen can arrive.
  2. Neo4j is the market leader in graph databases and has also been proven by Dataportal, Airbnb's data discovery tool.
+

Why Atlas?

+
    +
  1. Atlas has lineage support already available. It's been tried and tested.
  2. Tag/Badge propagation is supported.
  3. It has a robust authentication and authorization system.
  4. Atlas does data governance; adding Amundsen for discovery makes it the best of both worlds.
  5. Apache Atlas is the only proxy in Amundsen supporting both push and pull approaches for collecting metadata:
    • Push method, by leveraging the Apache Atlas Hive Hook. It's an event listener running alongside the Hive Metastore, translating Hive Metastore events into Apache Atlas entities and pushing them to a Kafka topic, from which Apache Atlas ingests the data by internal processes.
    • Pull method, by leveraging the Amundsen Databuilder integration with Apache Atlas. It means that extractors available in Databuilder can be used to collect metadata about external systems (like PostgresMetadataExtractor) and send it to Apache Atlas in a shape consumable by Amundsen. The Amundsen <> Atlas integration is prepared in such a way that you can use both push and pull models at the same time.
  6. The free version of Neo4j does not have authorization support (the Enterprise version does). The question should actually be why use "Neo4j over JanusGraph", because that is the right level of comparison. Atlas adds a whole bunch on top of the graph database.
+

Why not Atlas?

+
    +
  1. Atlas seems to have a slow development cycle and its community is not very responsive, although some small improvements have been made.
  2. The Atlas integration has less community support, meaning new features might land slightly later for Atlas in comparison to Neo4j.
+

What are the prerequisites to use Apache Atlas as backend for Amundsen?

+

To run Amundsen with Atlas, the latest versions of the following components should be used:
1. Apache Atlas - built from the master branch. Ref 103e867cc126ddb84e64bf262791a01a55bee6e5 (or higher).
2. amundsenatlastypes - a library for installing the Atlas entity definitions specific to the Amundsen integration. Version 1.3.0 (or higher).

+

How to migrate from Amundsen 1.x -> 2.x?

+

v2.0 renames a handful of fields in the services to be more consistent. Unfortunately one side effect is that the 2.0 versions of the services will need to be deployed simultaneously, as they are not interoperable with the 1.x versions.

+

Additionally, some indexed field names in the Elasticsearch documents change as well, so if you're using Elasticsearch, you'll need to republish the Elasticsearch index via a Databuilder job.

+

The data in the metadata store, however, can be preserved when migrating from 1.x to 2.0.

+

A v2.0 deployment consists of deploying all three services along with republishing the Elasticsearch Table documents with the v2.0 Databuilder.

+

Keep in mind there is likely to be some downtime with v2.0.0, between deploying the three services and re-seeding the Elasticsearch indexes, so it might be ideal to stage the rollout by datacenter/environment if uptime is key.

+

How to avoid certain metadata in Amundsen getting erased by Databuilder ingestion?

+

By default, Databuilder always upserts the metadata. If you want to prevent that from happening for certain types of metadata, you can add the following config to your Databuilder job's config:

+
'publisher.neo4j.{}'.format(neo4j_csv_publisher.NEO4J_CREATE_ONLY_NODES): [DESCRIPTION_NODE_LABEL],
+
+

This config means that Databuilder will only create the table / column description if it does not already exist (for example, when the table is newly created). This is useful when we treat the Amundsen graph as the source of truth for certain types of metadata (e.g. descriptions); a fuller sketch of where this option fits in a job config is shown below.
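For context, here is a minimal sketch of how this option can sit inside the Neo4j publisher section of a Databuilder job config. The extractor/loader/publisher wiring is elided, and the import paths for neo4j_csv_publisher and DESCRIPTION_NODE_LABEL follow the usual Databuilder layout but should be verified against your installed version:

from pyhocon import ConfigFactory

# Import paths assumed from the standard Databuilder layout -- double-check
# them against the Databuilder version you have installed.
from databuilder.publisher import neo4j_csv_publisher
from databuilder.models.table_metadata import DESCRIPTION_NODE_LABEL

job_config = ConfigFactory.from_dict({
    # ... extractor / loader / publisher settings from your existing job ...

    # Nodes with these labels are created when missing but never updated,
    # so descriptions curated in Amundsen are not overwritten by ingestion.
    'publisher.neo4j.{}'.format(neo4j_csv_publisher.NEO4J_CREATE_ONLY_NODES):
        [DESCRIPTION_NODE_LABEL],
})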

+

How to capture all Google Analytics?

+

Users are likely to have some sort of adblocker installed, making your Google Analytics less accurate.

+

To put a proxy in place to bypass any adblockers and capture all analytics, follow these steps:

+
    +
  1. Follow https://github.com/ZitRos/save-analytics-from-content-blockers#setup to set up your own proxy server.
  2. In the same repository, run npm run mask www.googletagmanager.com/gtag/js?id=UA-XXXXXXXXX and save the output.
  3. In your custom frontend, override https://github.com/amundsen-io/amundsenfrontendlibrary/blob/master/amundsen_application/static/templates/fragments/google-analytics-loader.html#L6 to
  4. Now, note that network requests to www.googletagmanager.com will be sent from behind your masked proxy endpoint, saving your analytics from content blockers!
\ No newline at end of file
diff --git a/frontend/CHANGELOG/index.html b/frontend/CHANGELOG/index.html new file mode 100644 index 0000000000..d394415b26 --- /dev/null +++ b/frontend/CHANGELOG/index.html @@ -0,0 +1,1426 @@

CHANGELOG

+ + + +

Feature

+
    +
  • Table and Column Lineage Polish (#970) (cd2f4c4)
  • Table and Column Lineage Lists (#969) (df9532a)
  • Add Table Notices (#957) (e3be638)
  • Allows for splitting stats' distinct values into a different element that shows in modal (#960) (fe04a06)
+

Fix

+
    +
  • Upgrade mypy version to build with Python3.8 (#975) (18963ec)
  • Handles parsing errors when format not expected on distinct values (#966) (473bbdb)
  • Made commit author consistent (#917) (48441cd)
  • Yaml syntax error (#913) (8f49627)
  • Add chore to monthly release PRs (#912) (9323862)
  • Removed echo for changelog command (#910) (bb22d4d)
  • Add changelog file (#907) (f06c50e)
  • Made change to preserve format of changelog (#896) (0d56d72)
  • Fixed reviewers field syntax error (#892) (b7f99d4)
  • Made branch eval and added reviewers (#891) (dd57d44)
  • Changed release workflow completely (#882) (5dfcd09)
  • Index tag info into elasticsearch immediately after ui change (#883) (b34151c)
+ + + + + + + + + \ No newline at end of file diff --git a/frontend/LICENSE b/frontend/LICENSE new file mode 100644 index 0000000000..a1c70dc855 --- /dev/null +++ b/frontend/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2018 Lyft, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/frontend/MANIFEST.in b/frontend/MANIFEST.in new file mode 100644 index 0000000000..407c9fbce2 --- /dev/null +++ b/frontend/MANIFEST.in @@ -0,0 +1,9 @@ +recursive-include amundsen_application/static/dist * +recursive-include amundsen_application/static/fonts * +recursive-include amundsen_application/static/images * + +recursive-include amundsen_application/.*/static/dist * +recursive-include amundsen_application/.*/static/fonts * +recursive-include amundsen_application/.*/static/images * + +global-include requirements-*txt diff --git a/frontend/Makefile b/frontend/Makefile new file mode 100644 index 0000000000..e6179d4f90 --- /dev/null +++ b/frontend/Makefile @@ -0,0 +1,58 @@ +IMAGE := amundsendev/amundsen-frontend +OIDC_IMAGE := ${IMAGE}-oidc +VERSION:= $(shell grep -m 1 '__version__' setup.py | cut -d '=' -f 2 | tr -d "'" | tr -d '[:space:]') + +.PHONY: clean +clean: + find . -name \*.pyc -delete + find . -name __pycache__ -delete + rm -rf dist/ + +.PHONY: test_unit +test_unit: + python3 -bb -m pytest tests + +.PHONY: lint +lint: + flake8 . + +.PHONY: mypy +mypy: + mypy --ignore-missing-imports --follow-imports=skip --strict-optional --warn-no-return . + +.PHONY: test +test: test_unit lint mypy + +.PHONY: image +image: + cd .. && docker build -f Dockerfile.frontend.public -t ${IMAGE}:latest . && cd frontend + +.PHONY: image-version +image-version: + cd .. && docker build -f Dockerfile.frontend.public -t ${IMAGE}:${VERSION} . && cd frontend + +.PHONY: push-image-version +push-image-version: + docker push ${IMAGE}:${VERSION} + +.PHONY: push-image +push-image: + docker push ${IMAGE}:latest + +.PHONY: oidc-image +oidc-image: + cd .. && docker build -f Dockerfile.frontend.public --target=oidc-release -t ${OIDC_IMAGE}:${VERSION} . && cd frontend + docker tag ${OIDC_IMAGE}:${VERSION} ${OIDC_IMAGE}:latest + +.PHONY: push-odic-image +push-oidc-image: + docker push ${OIDC_IMAGE}:${VERSION} + docker push ${OIDC_IMAGE}:latest + +.PHONY: build-push-image +build-push-image-latest: image oidc-image push-image push-oidc-image +build-push-image-version: image-version push-image-version + +.PHONY: install_deps +install_deps: + pip3 install -e ".[all]" diff --git a/frontend/NOTICE b/frontend/NOTICE new file mode 100644 index 0000000000..13f64901b8 --- /dev/null +++ b/frontend/NOTICE @@ -0,0 +1,4 @@ +amundsenfrontendlibrary +Copyright 2018-2019 Lyft Inc. + +This product includes software developed at Lyft Inc. diff --git a/frontend/amundsen_application/__init__.py b/frontend/amundsen_application/__init__.py new file mode 100644 index 0000000000..6acdaf0d93 --- /dev/null +++ b/frontend/amundsen_application/__init__.py @@ -0,0 +1,104 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import ast +import importlib +import logging +import logging.config +import os +import sys + +from flask import Blueprint, Flask +from flask_restful import Api +from typing import Optional + +from amundsen_application.api import init_routes +from amundsen_application.api.announcements.v0 import announcements_blueprint +from amundsen_application.api.issue.issue import IssueAPI, IssuesAPI +from amundsen_application.api.log.v0 import log_blueprint +from amundsen_application.api.mail.v0 import mail_blueprint +from amundsen_application.api.metadata.v0 import metadata_blueprint +from amundsen_application.api.preview.dashboard.v0 import \ + dashboard_preview_blueprint +from amundsen_application.api.preview.v0 import preview_blueprint +from amundsen_application.api.quality.v0 import quality_blueprint +from amundsen_application.api.search.v1 import search_blueprint +from amundsen_application.api.notice.v0 import notices_blueprint +from amundsen_application.api.v0 import blueprint +from amundsen_application.deprecations import process_deprecations + +# For customized flask use below arguments to override. + +FLASK_APP_MODULE_NAME = os.getenv('FLASK_APP_MODULE_NAME') or os.getenv('APP_WRAPPER') +FLASK_APP_CLASS_NAME = os.getenv('FLASK_APP_CLASS_NAME') or os.getenv('APP_WRAPPER_CLASS') +FLASK_APP_KWARGS_DICT_STR = os.getenv('FLASK_APP_KWARGS_DICT') or os.getenv('APP_WRAPPER_ARGS') + +""" Support for importing a subclass of flask.Flask, via env variables """ +if FLASK_APP_MODULE_NAME and FLASK_APP_CLASS_NAME: + print('Using requested Flask module {module_name} and class {class_name}' + .format(module_name=FLASK_APP_MODULE_NAME, class_name=FLASK_APP_CLASS_NAME), file=sys.stderr) + moduleName = FLASK_APP_MODULE_NAME + module = importlib.import_module(moduleName) + moduleClass = FLASK_APP_CLASS_NAME + app_wrapper_class = getattr(module, moduleClass) # type: ignore +else: + app_wrapper_class = Flask + +PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__)) +STATIC_ROOT = os.getenv('STATIC_ROOT', 'static') +static_dir = os.path.join(PROJECT_ROOT, STATIC_ROOT) + + +def create_app(config_module_class: Optional[str] = None, template_folder: Optional[str] = None) -> Flask: + """ Support for importing arguments for a subclass of flask.Flask """ + args = ast.literal_eval(FLASK_APP_KWARGS_DICT_STR) if FLASK_APP_KWARGS_DICT_STR else {} + + tmpl_dir = template_folder if template_folder else os.path.join(PROJECT_ROOT, static_dir, 'dist/templates') + app = app_wrapper_class(__name__, static_folder=static_dir, template_folder=tmpl_dir, **args) + + # Support for importing a custom config class + if not config_module_class: + config_module_class = os.getenv('FRONTEND_SVC_CONFIG_MODULE_CLASS') + + app.config.from_object(config_module_class) + + if app.config.get('LOG_CONFIG_FILE'): + logging.config.fileConfig(app.config['LOG_CONFIG_FILE'], disable_existing_loggers=False) + else: + logging.basicConfig(format=app.config['LOG_FORMAT'], datefmt=app.config.get('LOG_DATE_FORMAT')) + logging.getLogger().setLevel(app.config['LOG_LEVEL']) + + logging.info('Created app with config name {}'.format(config_module_class)) + logging.info('Using metadata service at {}'.format(app.config.get('METADATASERVICE_BASE'))) + logging.info('Using search service at {}'.format(app.config.get('SEARCHSERVICE_BASE'))) + + api_bp = Blueprint('api', __name__) + api = Api(api_bp) + + api.add_resource(IssuesAPI, + '/api/issue/issues', endpoint='issues') + api.add_resource(IssueAPI, + '/api/issue/issue', 
endpoint='issue') + + app.register_blueprint(blueprint) + app.register_blueprint(announcements_blueprint) + app.register_blueprint(log_blueprint) + app.register_blueprint(mail_blueprint) + app.register_blueprint(metadata_blueprint) + app.register_blueprint(preview_blueprint) + app.register_blueprint(quality_blueprint) + app.register_blueprint(search_blueprint) + app.register_blueprint(notices_blueprint) + app.register_blueprint(api_bp) + app.register_blueprint(dashboard_preview_blueprint) + init_routes(app) + + init_custom_routes = app.config.get('INIT_CUSTOM_ROUTES') + if init_custom_routes: + init_custom_routes(app) + + # handles the deprecation warnings + # and process any config/environment variables accordingly + process_deprecations(app) + + return app diff --git a/frontend/amundsen_application/api/__init__.py b/frontend/amundsen_application/api/__init__.py new file mode 100644 index 0000000000..6c5e44c724 --- /dev/null +++ b/frontend/amundsen_application/api/__init__.py @@ -0,0 +1,51 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Tuple +import logging + +from flask import Flask, render_template, make_response +import jinja2 +import os + + +ENVIRONMENT = os.getenv('APPLICATION_ENV', 'development') +LOGGER = logging.getLogger(__name__) + + +def init_routes(app: Flask) -> None: + frontend_base = app.config.get('FRONTEND_BASE') + config_override_enabled = app.config.get('JS_CONFIG_OVERRIDE_ENABLED') + + app.add_url_rule('/healthcheck', 'healthcheck', healthcheck) + app.add_url_rule('/opensearch.xml', 'opensearch.xml', opensearch, defaults={'frontend_base': frontend_base}) + app.add_url_rule('/', 'index', index, defaults={'path': '', + 'config_override_enabled': config_override_enabled, + 'frontend_base': frontend_base}) # also functions as catch_all + app.add_url_rule('/', 'index', index, + defaults={'frontend_base': frontend_base, + 'config_override_enabled': config_override_enabled}) # catch_all + + +def index(path: str, frontend_base: str, config_override_enabled: bool) -> Any: + try: + return render_template("index.html", env=ENVIRONMENT, frontend_base=frontend_base, + config_override_enabled=config_override_enabled) # pragma: no cover + except jinja2.exceptions.TemplateNotFound as e: + LOGGER.error("index.html template not found, have you built the front-end JS (npm run build in static/?") + raise e + + +def healthcheck() -> Tuple[str, int]: + return '', 200 # pragma: no cover + + +def opensearch(frontend_base: str) -> Any: + try: + template = render_template("opensearch.xml", frontend_base=frontend_base) + response = make_response(template) + response.headers['Content-Type'] = 'application/xml' + return response + except jinja2.exceptions.TemplateNotFound as e: + LOGGER.error("opensearch.xml template not found, have you built the front-end JS (npm run build in static/?") + raise e diff --git a/frontend/amundsen_application/api/announcements/__init__.py b/frontend/amundsen_application/api/announcements/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/frontend/amundsen_application/api/announcements/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 diff --git a/frontend/amundsen_application/api/announcements/v0.py b/frontend/amundsen_application/api/announcements/v0.py new file mode 100644 index 0000000000..cdf59f45af --- /dev/null +++ b/frontend/amundsen_application/api/announcements/v0.py @@ -0,0 +1,48 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from pkg_resources import iter_entry_points + +from http import HTTPStatus + +from flask import Response, jsonify, make_response, current_app as app +from flask.blueprints import Blueprint +from werkzeug.utils import import_string + +LOGGER = logging.getLogger(__name__) +ANNOUNCEMENT_CLIENT_CLASS = None +ANNOUNCEMENT_CLIENT_INSTANCE = None + +for entry_point in iter_entry_points(group='announcement_client', name='announcement_client_class'): + announcement_client_class = entry_point.load() + if announcement_client_class is not None: + ANNOUNCEMENT_CLIENT_CLASS = announcement_client_class + +announcements_blueprint = Blueprint('announcements', __name__, url_prefix='/api/announcements/v0') + + +@announcements_blueprint.route('/', methods=['GET']) +def get_announcements() -> Response: + global ANNOUNCEMENT_CLIENT_INSTANCE + global ANNOUNCEMENT_CLIENT_CLASS + try: + if ANNOUNCEMENT_CLIENT_INSTANCE is None: + if ANNOUNCEMENT_CLIENT_CLASS is not None: + ANNOUNCEMENT_CLIENT_INSTANCE = ANNOUNCEMENT_CLIENT_CLASS() + logging.warn('Setting announcement_client via entry_point is DEPRECATED' + ' and will be removed in a future version') + elif (app.config['ANNOUNCEMENT_CLIENT_ENABLED'] + and app.config['ANNOUNCEMENT_CLIENT'] is not None): + ANNOUNCEMENT_CLIENT_CLASS = import_string(app.config['ANNOUNCEMENT_CLIENT']) + ANNOUNCEMENT_CLIENT_INSTANCE = ANNOUNCEMENT_CLIENT_CLASS() + else: + payload = jsonify({'posts': [], + 'msg': 'A client for retrieving announcements must be configured'}) + return make_response(payload, HTTPStatus.NOT_IMPLEMENTED) + return ANNOUNCEMENT_CLIENT_INSTANCE._get_posts() + except Exception as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + payload = jsonify({'posts': [], 'msg': message}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/frontend/amundsen_application/api/exceptions.py b/frontend/amundsen_application/api/exceptions.py new file mode 100644 index 0000000000..2006671f88 --- /dev/null +++ b/frontend/amundsen_application/api/exceptions.py @@ -0,0 +1,9 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + + +class MailClientNotImplemented(Exception): + """ + An exception when Mail Client is not implemented + """ + pass diff --git a/frontend/amundsen_application/api/issue/__init__.py b/frontend/amundsen_application/api/issue/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/frontend/amundsen_application/api/issue/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/frontend/amundsen_application/api/issue/issue.py b/frontend/amundsen_application/api/issue/issue.py new file mode 100644 index 0000000000..987bf1bc2a --- /dev/null +++ b/frontend/amundsen_application/api/issue/issue.py @@ -0,0 +1,90 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from flask import current_app as app +from flask import jsonify, make_response, Response +from flask_restful import Resource, reqparse +from http import HTTPStatus +import logging + +from amundsen_application.base.base_issue_tracker_client import BaseIssueTrackerClient +from amundsen_application.proxy.issue_tracker_clients import get_issue_tracker_client +from amundsen_application.proxy.issue_tracker_clients.issue_exceptions import IssueConfigurationException + +LOGGER = logging.getLogger(__name__) + + +class IssuesAPI(Resource): + def __init__(self) -> None: + self.reqparse = reqparse.RequestParser() + self.client: BaseIssueTrackerClient + + def get(self) -> Response: + """ + Given a table key, returns all tickets containing that key. Returns an empty array if none exist + :return: List of tickets + """ + try: + if not app.config['ISSUE_TRACKER_CLIENT_ENABLED']: + message = 'Issue tracking is not enabled. Request was accepted but no issue will be returned.' + logging.exception(message) + return make_response(jsonify({'msg': message}), HTTPStatus.ACCEPTED) + + self.client = get_issue_tracker_client() + self.reqparse.add_argument('key', 'Request requires a key', location='args') + args = self.reqparse.parse_args() + response = self.client.get_issues(args['key']) + return make_response(jsonify({'issues': response.serialize()}), HTTPStatus.OK) + + except IssueConfigurationException as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + return make_response(jsonify({'msg': message}), HTTPStatus.NOT_IMPLEMENTED) + except Exception as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + return make_response(jsonify({'msg': message}), HTTPStatus.INTERNAL_SERVER_ERROR) + + +class IssueAPI(Resource): + def __init__(self) -> None: + self.reqparse = reqparse.RequestParser() + self.client: BaseIssueTrackerClient + super(IssueAPI, self).__init__() + + def post(self) -> Response: + try: + if not app.config['ISSUE_TRACKER_CLIENT_ENABLED']: + message = 'Issue tracking is not enabled. Request was accepted but no issue will be created.' 
+ logging.exception(message) + return make_response(jsonify({'msg': message}), HTTPStatus.ACCEPTED) + self.client = get_issue_tracker_client() + + self.reqparse.add_argument('title', type=str, location='json') + self.reqparse.add_argument('key', type=str, location='json') + self.reqparse.add_argument('description', type=str, location='json') + self.reqparse.add_argument('owner_ids', type=list, location='json') + self.reqparse.add_argument('frequent_user_ids', type=list, location='json') + self.reqparse.add_argument('priority_level', type=str, location='json') + self.reqparse.add_argument('project_key', type=str, location='json') + self.reqparse.add_argument('resource_path', type=str, location='json') + args = self.reqparse.parse_args() + response = self.client.create_issue(description=args['description'], + owner_ids=args['owner_ids'], + frequent_user_ids=args['frequent_user_ids'], + priority_level=args['priority_level'], + project_key=args['project_key'], + table_uri=args['key'], + title=args['title'], + table_url=app.config['FRONTEND_BASE'] + args['resource_path'] + if args['resource_path'] else 'Not Found') + return make_response(jsonify({'issue': response.serialize()}), HTTPStatus.OK) + + except IssueConfigurationException as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + return make_response(jsonify({'msg': message}), HTTPStatus.NOT_IMPLEMENTED) + except Exception as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + return make_response(jsonify({'msg': message}), HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/frontend/amundsen_application/api/log/__init__.py b/frontend/amundsen_application/api/log/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/frontend/amundsen_application/api/log/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/frontend/amundsen_application/api/log/v0.py b/frontend/amundsen_application/api/log/v0.py new file mode 100644 index 0000000000..05888ec3b2 --- /dev/null +++ b/frontend/amundsen_application/api/log/v0.py @@ -0,0 +1,64 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging + +from http import HTTPStatus + +from flask import Response, jsonify, make_response, request +from flask.blueprints import Blueprint + +from amundsen_application.log.action_log import action_logging +from amundsen_application.api.utils.request_utils import get_query_param + + +LOGGER = logging.getLogger(__name__) + +log_blueprint = Blueprint('log', __name__, url_prefix='/api/log/v0') + + +@log_blueprint.route('/log_event', methods=['POST']) +def log_generic_action() -> Response: + """ + Log a generic action on the frontend. Captured parameters include + + :param command: Req. User Action E.g. click, scroll, hover, search, etc + :param target_id: Req. Unique identifier for the object acted upon E.g. tag::payments, table::schema.database + :param target_type: Opt. Type of element event took place on (button, link, tag, icon, etc) + :param label: Opt. Displayed text for target + :param location: Opt. Where the the event occurred + :param value: Opt. 
Value to be logged + :return: + """ + @action_logging + def _log_generic_action(*, + command: str, + target_id: str, + target_type: str, + label: str, + location: str, + value: str, + position: str) -> None: + pass # pragma: no cover + + try: + args = request.get_json() + command = get_query_param(args, 'command', '"command" is a required parameter.') + target_id = get_query_param(args, 'target_id', '"target_id" is a required field.') + _log_generic_action( + command=command, + target_id=target_id, + target_type=args.get('target_type', None), + label=args.get('label', None), + location=args.get('location', None), + value=args.get('value', None), + position=args.get('position', None) + ) + message = 'Logging of {} action successful'.format(command) + return make_response(jsonify({'msg': message}), HTTPStatus.OK) + + except Exception as e: + message = 'Log action failed. Encountered exception: ' + str(e) + logging.exception(message) + payload = jsonify({'msg': message}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/frontend/amundsen_application/api/mail/__init__.py b/frontend/amundsen_application/api/mail/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/frontend/amundsen_application/api/mail/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/frontend/amundsen_application/api/mail/v0.py b/frontend/amundsen_application/api/mail/v0.py new file mode 100644 index 0000000000..fe0bcc9da1 --- /dev/null +++ b/frontend/amundsen_application/api/mail/v0.py @@ -0,0 +1,121 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging + +from http import HTTPStatus + +from flask import Response, jsonify, make_response, request +from flask import current_app as app +from flask.blueprints import Blueprint + +from amundsen_application.api.exceptions import MailClientNotImplemented +from amundsen_application.api.utils.notification_utils import get_mail_client, send_notification +from amundsen_application.log.action_log import action_logging + +LOGGER = logging.getLogger(__name__) + +mail_blueprint = Blueprint('mail', __name__, url_prefix='/api/mail/v0') + + +@mail_blueprint.route('/feedback', methods=['POST']) +def feedback() -> Response: + """ + Uses the instance of BaseMailClient client configured on the MAIL_CLIENT + config variable to send an email with feedback data + """ + try: + mail_client = get_mail_client() + data = request.form.to_dict() + html_content = ''.join('
{}:
{}

'.format(k, v) for k, v in data.items()) + + # action logging + feedback_type = data.get('feedback-type') + rating = data.get('rating') + comment = data.get('comment') + bug_summary = data.get('bug-summary') + repro_steps = data.get('repro-steps') + feature_summary = data.get('feature-summary') + value_prop = data.get('value-prop') + subject = data.get('subject') or data.get('feedback-type') + + _feedback(feedback_type=feedback_type, + rating=rating, + comment=comment, + bug_summary=bug_summary, + repro_steps=repro_steps, + feature_summary=feature_summary, + value_prop=value_prop, + subject=subject) + + options = { + 'email_type': 'feedback', + 'form_data': data + } + + response = mail_client.send_email(html=html_content, subject=subject, optional_data=options) + status_code = response.status_code + + if 200 <= status_code < 300: + message = 'Success' + else: + message = 'Mail client failed with status code ' + str(status_code) + logging.error(message) + + return make_response(jsonify({'msg': message}), status_code) + except MailClientNotImplemented as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + return make_response(jsonify({'msg': message}), HTTPStatus.NOT_IMPLEMENTED) + except Exception as e1: + message = 'Encountered exception: ' + str(e1) + logging.exception(message) + return make_response(jsonify({'msg': message}), HTTPStatus.INTERNAL_SERVER_ERROR) + + +@action_logging +def _feedback(*, + feedback_type: str, + rating: str, + comment: str, + bug_summary: str, + repro_steps: str, + feature_summary: str, + value_prop: str, + subject: str) -> None: + """ Logs the content of the feedback form """ + pass # pragma: no cover + + +@mail_blueprint.route('/notification', methods=['POST']) +def notification() -> Response: + """ + Uses the instance of BaseMailClient client configured on the MAIL_CLIENT + config variable to send a notification email based on data passed from the request + """ + try: + data = request.get_json() + + notification_type = data.get('notificationType') + if notification_type is None: + message = 'Encountered exception: notificationType must be provided in the request payload' + logging.exception(message) + return make_response(jsonify({'msg': message}), HTTPStatus.BAD_REQUEST) + + sender = data.get('sender') + if sender is None: + sender = app.config['AUTH_USER_METHOD'](app).email + + options = data.get('options', {}) + recipients = data.get('recipients', []) + + return send_notification( + notification_type=notification_type, + options=options, + recipients=recipients, + sender=sender + ) + except Exception as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + return make_response(jsonify({'msg': message}), HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/frontend/amundsen_application/api/metadata/__init__.py b/frontend/amundsen_application/api/metadata/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/frontend/amundsen_application/api/metadata/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/frontend/amundsen_application/api/metadata/v0.py b/frontend/amundsen_application/api/metadata/v0.py new file mode 100644 index 0000000000..b842ccabff --- /dev/null +++ b/frontend/amundsen_application/api/metadata/v0.py @@ -0,0 +1,1173 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import json + +from http import HTTPStatus +from typing import Any, Dict, Optional + +from flask import Response, jsonify, make_response, request +from flask import current_app as app +from flask.blueprints import Blueprint + +from amundsen_common.entity.resource_type import ResourceType, to_label +from amundsen_common.models.search import UpdateDocumentRequestSchema, UpdateDocumentRequest + +from amundsen_application.log.action_log import action_logging + +from amundsen_application.models.user import load_user, dump_user + +from amundsen_application.api.utils.metadata_utils import is_table_editable, marshall_table_partial, \ + marshall_table_full, marshall_dashboard_partial, marshall_dashboard_full, marshall_feature_full, \ + marshall_lineage_table, TableUri +from amundsen_application.api.utils.request_utils import get_query_param, request_metadata + +from amundsen_application.api.utils.search_utils import execute_search_document_request + + +LOGGER = logging.getLogger(__name__) + + +metadata_blueprint = Blueprint('metadata', __name__, url_prefix='/api/metadata/v0') + +TABLE_ENDPOINT = '/table' +TYPE_METADATA_ENDPOINT = '/type_metadata' +FEATURE_ENDPOINT = '/feature' +LAST_INDEXED_ENDPOINT = '/latest_updated_ts' +POPULAR_RESOURCES_ENDPOINT = '/popular_resources' +TAGS_ENDPOINT = '/tags/' +BADGES_ENDPOINT = '/badges/' +USER_ENDPOINT = '/user' +DASHBOARD_ENDPOINT = '/dashboard' + + +def _get_table_endpoint() -> str: + metadata_service_base = app.config['METADATASERVICE_BASE'] + if metadata_service_base is None: + raise Exception('METADATASERVICE_BASE must be configured') + return metadata_service_base + TABLE_ENDPOINT + + +def _get_type_metadata_endpoint() -> str: + metadata_service_base = app.config['METADATASERVICE_BASE'] + if metadata_service_base is None: + raise Exception('METADATASERVICE_BASE must be configured') + return metadata_service_base + TYPE_METADATA_ENDPOINT + + +def _get_feature_endpoint() -> str: + metadata_service_base = app.config['METADATASERVICE_BASE'] + if metadata_service_base is None: + raise Exception('METADATASERVICE_BASE must be configured') + return metadata_service_base + FEATURE_ENDPOINT + + +def _get_dashboard_endpoint() -> str: + metadata_service_base = app.config['METADATASERVICE_BASE'] + if metadata_service_base is None: + raise Exception('METADATASERVICE_BASE must be configured') + return metadata_service_base + DASHBOARD_ENDPOINT + + +@metadata_blueprint.route('/popular_resources', methods=['GET']) +def popular_resources() -> Response: + """ + call the metadata service endpoint to get the current popular tables, dashboards etc. 
+ this takes a required query parameter "types", that is a comma separated string of requested resource types + :return: a json output containing an array of popular table metadata as 'popular_tables' + + Schema Defined Here: + https://github.com/lyft/amundsenmetadatalibrary/blob/master/metadata_service/api/popular_tables.py + """ + try: + if app.config['AUTH_USER_METHOD'] and app.config['POPULAR_RESOURCES_PERSONALIZATION']: + user_id = app.config['AUTH_USER_METHOD'](app).user_id + else: + user_id = '' + + resource_types = get_query_param(request.args, 'types') + + service_base = app.config['METADATASERVICE_BASE'] + count = app.config['POPULAR_RESOURCES_COUNT'] + url = f'{service_base}{POPULAR_RESOURCES_ENDPOINT}/{user_id}?limit={count}&types={resource_types}' + + response = request_metadata(url=url) + status_code = response.status_code + + if status_code == HTTPStatus.OK: + message = 'Success' + json_response = response.json() + tables = json_response.get(ResourceType.Table.name, []) + popular_tables = [marshall_table_partial(result) for result in tables] + dashboards = json_response.get(ResourceType.Dashboard.name, []) + popular_dashboards = [marshall_dashboard_partial(dashboard) for dashboard in dashboards] + else: + message = 'Encountered error: Request to metadata service failed with status code ' + str(status_code) + logging.error(message) + popular_tables = [] + popular_dashboards = [] + + all_popular_resources = { + to_label(resource_type=ResourceType.Table): popular_tables, + to_label(resource_type=ResourceType.Dashboard): popular_dashboards + } + + payload = jsonify({'results': all_popular_resources, 'msg': message}) + return make_response(payload, status_code) + except Exception as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + payload = jsonify({'results': [{}], 'msg': message}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/table', methods=['GET']) +def get_table_metadata() -> Response: + """ + call the metadata service endpoint and return matching results + :return: a json output containing a table metadata object as 'tableData' + + Schema Defined Here: https://github.com/lyft/amundsenmetadatalibrary/blob/master/metadata_service/api/table.py + TODO: Define type for this + + TODO: Define an interface for envoy_client + """ + try: + table_key = get_query_param(request.args, 'key') + list_item_index = request.args.get('index', None) + list_item_source = request.args.get('source', None) + + results_dict = _get_table_metadata(table_key=table_key, index=list_item_index, source=list_item_source) + return make_response(jsonify(results_dict), results_dict.get('status_code', HTTPStatus.INTERNAL_SERVER_ERROR)) + except Exception as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + return make_response(jsonify({'tableData': {}, 'msg': message}), HTTPStatus.INTERNAL_SERVER_ERROR) + + +@action_logging +def _get_table_metadata(*, table_key: str, index: int, source: str) -> Dict[str, Any]: + + results_dict: Dict[str, Any] = { + 'tableData': {}, + 'msg': '', + } + + try: + table_endpoint = _get_table_endpoint() + url = '{0}/{1}'.format(table_endpoint, table_key) + response = request_metadata(url=url) + except ValueError as e: + # envoy client BadResponse is a subclass of ValueError + message = 'Encountered exception: ' + str(e) + results_dict['msg'] = message + results_dict['status_code'] = getattr(e, 'code', HTTPStatus.INTERNAL_SERVER_ERROR) + logging.exception(message) + 
return results_dict + + status_code = response.status_code + results_dict['status_code'] = status_code + + if status_code != HTTPStatus.OK: + message = 'Encountered error: Metadata request failed' + results_dict['msg'] = message + logging.error(message) + return results_dict + + try: + table_data_raw: dict = response.json() + + # Ideally the response should include 'key' to begin with + table_data_raw['key'] = table_key + + results_dict['tableData'] = marshall_table_full(table_data_raw) + results_dict['msg'] = 'Success' + return results_dict + except Exception as e: + message = 'Encountered exception: ' + str(e) + results_dict['msg'] = message + logging.exception(message) + # explicitly raise the exception which will trigger 500 api response + results_dict['status_code'] = getattr(e, 'code', HTTPStatus.INTERNAL_SERVER_ERROR) + return results_dict + + +@metadata_blueprint.route('/update_table_owner', methods=['PUT', 'DELETE']) +def update_table_owner() -> Response: + + @action_logging + def _log_update_table_owner(*, table_key: str, method: str, owner: str) -> None: + pass # pragma: no cover + + try: + args = request.get_json() + table_key = get_query_param(args, 'key') + owner = get_query_param(args, 'owner') + + table_endpoint = _get_table_endpoint() + url = '{0}/{1}/owner/{2}'.format(table_endpoint, table_key, owner) + method = request.method + _log_update_table_owner(table_key=table_key, method=method, owner=owner) + + response = request_metadata(url=url, method=method) + status_code = response.status_code + + if status_code == HTTPStatus.OK: + message = 'Updated owner' + else: + message = 'There was a problem updating owner {0}'.format(owner) + + payload = jsonify({'msg': message}) + return make_response(payload, status_code) + except Exception as e: + payload = jsonify({'msg': 'Encountered exception: ' + str(e)}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/get_last_indexed') +def get_last_indexed() -> Response: + """ + call the metadata service endpoint to get the last indexed timestamp of neo4j + :return: a json output containing the last indexed timestamp, in unix epoch time, as 'timestamp' + + Schema Defined Here: https://github.com/lyft/amundsenmetadatalibrary/blob/master/metadata_service/api/system.py + """ + try: + url = app.config['METADATASERVICE_BASE'] + LAST_INDEXED_ENDPOINT + + response = request_metadata(url=url) + status_code = response.status_code + + if status_code == HTTPStatus.OK: + message = 'Success' + timestamp = response.json().get('neo4j_latest_timestamp') + else: + message = 'Timestamp Unavailable' + timestamp = None + + payload = jsonify({'timestamp': timestamp, 'msg': message}) + return make_response(payload, status_code) + except Exception as e: + payload = jsonify({'timestamp': None, 'msg': 'Encountered exception: ' + str(e)}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/get_table_description', methods=['GET']) +def get_table_description() -> Response: + try: + table_endpoint = _get_table_endpoint() + table_key = get_query_param(request.args, 'key') + + url = '{0}/{1}/description'.format(table_endpoint, table_key) + + response = request_metadata(url=url) + status_code = response.status_code + + if status_code == HTTPStatus.OK: + message = 'Success' + description = response.json().get('description') + else: + message = 'Get table description failed' + description = None + + payload = jsonify({'description': description, 'msg': message}) + return 
make_response(payload, status_code) + except Exception as e: + payload = jsonify({'description': None, 'msg': 'Encountered exception: ' + str(e)}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/get_column_description', methods=['GET']) +def get_column_description() -> Response: + try: + table_endpoint = _get_table_endpoint() + table_key = get_query_param(request.args, 'key') + + column_name = get_query_param(request.args, 'column_name') + + url = '{0}/{1}/column/{2}/description'.format(table_endpoint, table_key, column_name) + + response = request_metadata(url=url) + status_code = response.status_code + + if status_code == HTTPStatus.OK: + message = 'Success' + description = response.json().get('description') + else: + message = 'Get column description failed' + description = None + + payload = jsonify({'description': description, 'msg': message}) + return make_response(payload, status_code) + except Exception as e: + payload = jsonify({'description': None, 'msg': 'Encountered exception: ' + str(e)}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/get_type_metadata_description', methods=['GET']) +def get_type_metadata_description() -> Response: + try: + type_metadata_endpoint = _get_type_metadata_endpoint() + + type_metadata_key = get_query_param(request.args, 'type_metadata_key') + + url = '{0}/{1}/description'.format(type_metadata_endpoint, type_metadata_key) + + response = request_metadata(url=url) + status_code = response.status_code + + if status_code == HTTPStatus.OK: + message = 'Success' + description = response.json().get('description') + else: + message = 'Get type metadata description failed' + description = None + + payload = jsonify({'description': description, 'msg': message}) + return make_response(payload, status_code) + except Exception as e: + payload = jsonify({'description': None, 'msg': 'Encountered exception: ' + str(e)}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/put_table_description', methods=['PUT']) +def put_table_description() -> Response: + + @action_logging + def _log_put_table_description(*, table_key: str, description: str, source: str) -> None: + pass # pragma: no cover + + try: + args = request.get_json() + table_endpoint = _get_table_endpoint() + + table_key = get_query_param(args, 'key') + + description = get_query_param(args, 'description') + src = get_query_param(args, 'source') + + table_uri = TableUri.from_uri(table_key) + if not is_table_editable(table_uri.schema, table_uri.table): + return make_response('', HTTPStatus.FORBIDDEN) + + url = '{0}/{1}/description'.format(table_endpoint, table_key) + _log_put_table_description(table_key=table_key, description=description, source=src) + + response = request_metadata(url=url, method='PUT', data=json.dumps({'description': description})) + status_code = response.status_code + + if status_code == HTTPStatus.OK: + message = 'Success' + else: + message = 'Update table description failed' + + payload = jsonify({'msg': message}) + return make_response(payload, status_code) + except Exception as e: + payload = jsonify({'msg': 'Encountered exception: ' + str(e)}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/put_column_description', methods=['PUT']) +def put_column_description() -> Response: + + @action_logging + def _log_put_column_description(*, table_key: str, column_name: str, description: str, source: 
str) -> None: + pass # pragma: no cover + + try: + args = request.get_json() + + table_key = get_query_param(args, 'key') + table_endpoint = _get_table_endpoint() + + column_name = get_query_param(args, 'column_name') + description = get_query_param(args, 'description') + + src = get_query_param(args, 'source') + + table_uri = TableUri.from_uri(table_key) + if not is_table_editable(table_uri.schema, table_uri.table): + return make_response('', HTTPStatus.FORBIDDEN) + + url = '{0}/{1}/column/{2}/description'.format(table_endpoint, table_key, column_name) + _log_put_column_description(table_key=table_key, column_name=column_name, description=description, source=src) + + response = request_metadata(url=url, method='PUT', data=json.dumps({'description': description})) + status_code = response.status_code + + if status_code == HTTPStatus.OK: + message = 'Success' + else: + message = 'Update column description failed' + + payload = jsonify({'msg': message}) + return make_response(payload, status_code) + except Exception as e: + payload = jsonify({'msg': 'Encountered exception: ' + str(e)}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/put_type_metadata_description', methods=['PUT']) +def put_type_metadata_description() -> Response: + + @action_logging + def _log_put_type_metadata_description(*, type_metadata_key: str, description: str, source: str) -> None: + pass # pragma: no cover + + try: + args = request.get_json() + + type_metadata_endpoint = _get_type_metadata_endpoint() + + type_metadata_key = get_query_param(args, 'type_metadata_key') + description = get_query_param(args, 'description') + + src = get_query_param(args, 'source') + + table_key = get_query_param(args, 'table_key') + table_uri = TableUri.from_uri(table_key) + if not is_table_editable(table_uri.schema, table_uri.table): + return make_response('', HTTPStatus.FORBIDDEN) + + url = '{0}/{1}/description'.format(type_metadata_endpoint, type_metadata_key) + _log_put_type_metadata_description(type_metadata_key=type_metadata_key, description=description, source=src) + + response = request_metadata(url=url, method='PUT', data=json.dumps({'description': description})) + status_code = response.status_code + + if status_code == HTTPStatus.OK: + message = 'Success' + else: + message = 'Update type metadata description failed' + + payload = jsonify({'msg': message}) + return make_response(payload, status_code) + except Exception as e: + payload = jsonify({'msg': 'Encountered exception: ' + str(e)}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/tags') +def get_tags() -> Response: + """ + call the metadata service endpoint to get the list of all tags from metadata proxy + :return: a json output containing the list of all tags, as 'tags' + """ + try: + url = app.config['METADATASERVICE_BASE'] + TAGS_ENDPOINT + response = request_metadata(url=url) + status_code = response.status_code + + if status_code == HTTPStatus.OK: + message = 'Success' + tags = response.json().get('tag_usages') + else: + message = 'Encountered error: Tags Unavailable' + logging.error(message) + tags = [] + + payload = jsonify({'tags': tags, 'msg': message}) + return make_response(payload, status_code) + except Exception as e: + message = 'Encountered exception: ' + str(e) + payload = jsonify({'tags': [], 'msg': message}) + logging.exception(message) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/badges') +def 
get_badges() -> Response: + """ + call the metadata service endpoint to get the list of all badges from metadata proxy + :return: a json output containing the list of all badges, as 'badges' + """ + try: + url = app.config['METADATASERVICE_BASE'] + BADGES_ENDPOINT + response = request_metadata(url=url) + status_code = response.status_code + + if status_code == HTTPStatus.OK: + message = 'Success' + badges = response.json().get('badges') + else: + message = 'Encountered error: Badges Unavailable' + logging.error(message) + badges = [] + + payload = jsonify({'badges': badges, 'msg': message}) + return make_response(payload, status_code) + except Exception as e: + message = 'Encountered exception: ' + str(e) + payload = jsonify({'badges': [], 'msg': message}) + logging.exception(message) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +def _update_metadata_tag(table_key: str, method: str, tag: str) -> int: + table_endpoint = _get_table_endpoint() + url = f'{table_endpoint}/{table_key}/tag/{tag}' + response = request_metadata(url=url, method=method) + status_code = response.status_code + if status_code != HTTPStatus.OK: + LOGGER.info(f'Fail to update tag in metadataservice, http status code: {status_code}') + LOGGER.debug(response.text) + return status_code + + +@metadata_blueprint.route('/update_table_tags', methods=['PUT', 'DELETE']) +def update_table_tags() -> Response: + + @action_logging + def _log_update_table_tags(*, table_key: str, method: str, tag: str) -> None: + pass # pragma: no cover + + try: + args = request.get_json() + method = request.method + + table_key = get_query_param(args, 'key') + + tag = get_query_param(args, 'tag') + + _log_update_table_tags(table_key=table_key, method=method, tag=tag) + + metadata_status_code = _update_metadata_tag(table_key=table_key, method=method, tag=tag) + + search_method = method if method == 'DELETE' else 'POST' + update_request = UpdateDocumentRequest(resource_key=table_key, + resource_type=ResourceType.Table.name.lower(), + field='tag', + value=tag, + operation='add') + request_json = json.dumps(UpdateDocumentRequestSchema().dump(update_request)) + + search_status_code = execute_search_document_request(request_json=request_json, + method=search_method) + + http_status_code = HTTPStatus.OK + if metadata_status_code == HTTPStatus.OK and search_status_code == HTTPStatus.OK: + message = 'Success' + else: + message = f'Encountered error: {method} table tag failed' + logging.error(message) + http_status_code = HTTPStatus.INTERNAL_SERVER_ERROR + + payload = jsonify({'msg': message}) + return make_response(payload, http_status_code) + + except Exception as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + payload = jsonify({'msg': message}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/update_dashboard_tags', methods=['PUT', 'DELETE']) +def update_dashboard_tags() -> Response: + + @action_logging + def _log_update_dashboard_tags(*, uri_key: str, method: str, tag: str) -> None: + pass # pragma: no cover + + try: + args = request.get_json() + method = request.method + + dashboard_endpoint = _get_dashboard_endpoint() + uri_key = get_query_param(args, 'key') + tag = get_query_param(args, 'tag') + url = f'{dashboard_endpoint}/{uri_key}/tag/{tag}' + + _log_update_dashboard_tags(uri_key=uri_key, method=method, tag=tag) + + response = request_metadata(url=url, method=method) + status_code = response.status_code + + if status_code == HTTPStatus.OK: + 
message = 'Success' + else: + message = f'Encountered error: {method} dashboard tag failed' + logging.error(message) + + payload = jsonify({'msg': message}) + return make_response(payload, status_code) + except Exception as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + payload = jsonify({'msg': message}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/user', methods=['GET']) +def get_user() -> Response: + + @action_logging + def _log_get_user(*, user_id: str, index: Optional[int], source: Optional[str]) -> None: + pass # pragma: no cover + + try: + user_id = get_query_param(request.args, 'user_id') + index = request.args.get('index', None) + source = request.args.get('source', None) + + url = '{0}{1}/{2}'.format(app.config['METADATASERVICE_BASE'], USER_ENDPOINT, user_id) + _log_get_user(user_id=user_id, index=index, source=source) + + response = request_metadata(url=url) + status_code = response.status_code + + if status_code == HTTPStatus.OK: + message = 'Success' + else: + message = 'Encountered error: failed to fetch user with user_id: {0}'.format(user_id) + logging.error(message) + + payload = { + 'msg': message, + 'user': dump_user(load_user(response.json())), + } + return make_response(jsonify(payload), status_code) + except Exception as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + payload = {'msg': message} + return make_response(jsonify(payload), HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/user/bookmark', methods=['GET']) +def get_bookmark() -> Response: + """ + Call metadata service to fetch a specified user's bookmarks. + If no 'user_id' is specified, it will fetch the logged-in user's bookmarks + :param user_id: (optional) the user whose bookmarks are fetched. + :return: a JSON object with an array of bookmarks under 'bookmarks' key + """ + try: + user_id = request.args.get('user_id') + if user_id is None: + if app.config['AUTH_USER_METHOD']: + user_id = app.config['AUTH_USER_METHOD'](app).user_id + else: + raise Exception('AUTH_USER_METHOD is not configured') + + url = '{0}{1}/{2}/follow/'.format(app.config['METADATASERVICE_BASE'], USER_ENDPOINT, user_id) + + response = request_metadata(url=url, method=request.method) + status_code = response.status_code + + if status_code == HTTPStatus.OK: + message = 'Success' + tables = response.json().get('table') + table_bookmarks = [marshall_table_partial(table) for table in tables] + dashboards = response.json().get('dashboard', []) + dashboard_bookmarks = [marshall_dashboard_partial(dashboard) for dashboard in dashboards] + else: + message = f'Encountered error: failed to get bookmark for user_id: {user_id}' + logging.error(message) + table_bookmarks = [] + dashboard_bookmarks = [] + + all_bookmarks = { + 'table': table_bookmarks, + 'dashboard': dashboard_bookmarks + } + return make_response(jsonify({'msg': message, 'bookmarks': all_bookmarks}), status_code) + except Exception as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + return make_response(jsonify({'msg': message}), HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/user/bookmark', methods=['PUT', 'DELETE']) +def update_bookmark() -> Response: + """ + Call metadata service to PUT or DELETE a bookmark + Params + :param type: Resource type for the bookmarked item. e.g. 'table' + :param key: Resource key for the bookmarked item. 
+ :return: + """ + + @action_logging + def _log_update_bookmark(*, resource_key: str, resource_type: str, method: str) -> None: + pass # pragma: no cover + + try: + if app.config['AUTH_USER_METHOD']: + user = app.config['AUTH_USER_METHOD'](app) + else: + raise Exception('AUTH_USER_METHOD is not configured') + + args = request.get_json() + resource_type = get_query_param(args, 'type') + resource_key = get_query_param(args, 'key') + + url = '{0}{1}/{2}/follow/{3}/{4}'.format(app.config['METADATASERVICE_BASE'], + USER_ENDPOINT, + user.user_id, + resource_type, + resource_key) + + _log_update_bookmark(resource_key=resource_key, resource_type=resource_type, method=request.method) + + response = request_metadata(url=url, method=request.method) + status_code = response.status_code + + return make_response(jsonify({'msg': 'success', 'response': response.json()}), status_code) + except Exception as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + return make_response(jsonify({'msg': message}), HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/user/read', methods=['GET']) +def get_user_read() -> Response: + """ + Calls metadata service to GET read/frequently used resources + :return: a JSON object with an array of read resources + """ + try: + user_id = get_query_param(request.args, 'user_id') + + url = '{0}{1}/{2}/read/'.format(app.config['METADATASERVICE_BASE'], + USER_ENDPOINT, + user_id) + response = request_metadata(url=url, method=request.method) + status_code = response.status_code + read_tables_raw = response.json().get('table') + read_tables = [marshall_table_partial(table) for table in read_tables_raw] + return make_response(jsonify({'msg': 'success', 'read': read_tables}), status_code) + + except Exception as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + return make_response(jsonify({'msg': message}), HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/user/own', methods=['GET']) +def get_user_own() -> Response: + """ + Calls metadata service to GET owned resources + :return: a JSON object with an array of owned resources + """ + try: + user_id = get_query_param(request.args, 'user_id') + + url = '{0}{1}/{2}/own/'.format(app.config['METADATASERVICE_BASE'], + USER_ENDPOINT, + user_id) + response = request_metadata(url=url, method=request.method) + status_code = response.status_code + owned_tables_raw = response.json().get('table') + owned_tables = [marshall_table_partial(table) for table in owned_tables_raw] + dashboards = response.json().get('dashboard', []) + owned_dashboards = [marshall_dashboard_partial(dashboard) for dashboard in dashboards] + all_owned = { + 'table': owned_tables, + 'dashboard': owned_dashboards + } + return make_response(jsonify({'msg': 'success', 'own': all_owned}), status_code) + except Exception as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + return make_response(jsonify({'msg': message}), HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/dashboard', methods=['GET']) +def get_dashboard_metadata() -> Response: + """ + Call metadata service endpoint to fetch specified dashboard metadata + :return: + """ + @action_logging + def _get_dashboard_metadata(*, uri: str, index: int, source: str) -> None: + pass # pragma: no cover + + try: + uri = get_query_param(request.args, 'uri') + index = request.args.get('index', None) + source = request.args.get('source', None) + _get_dashboard_metadata(uri=uri, index=index, 
                                source=source)
+
+        url = f'{app.config["METADATASERVICE_BASE"]}{DASHBOARD_ENDPOINT}/{uri}'
+
+        response = request_metadata(url=url, method=request.method)
+        dashboard = marshall_dashboard_full(response.json())
+        status_code = response.status_code
+        return make_response(jsonify({'msg': 'success', 'dashboard': dashboard}), status_code)
+    except Exception as e:
+        message = 'Encountered exception: ' + str(e)
+        logging.exception(message)
+        return make_response(jsonify({'dashboard': {}, 'msg': message}), HTTPStatus.INTERNAL_SERVER_ERROR)
+
+
+@metadata_blueprint.route('/table/<path:table_key>/dashboards', methods=['GET'])
+def get_related_dashboard_metadata(table_key: str) -> Response:
+    """
+    Call metadata service endpoint to fetch related dashboard metadata
+    :return:
+    """
+    try:
+        url = f'{app.config["METADATASERVICE_BASE"]}{TABLE_ENDPOINT}/{table_key}/dashboard/'
+        results_dict = _get_related_dashboards_metadata(url=url)
+        return make_response(jsonify(results_dict), results_dict.get('status_code', HTTPStatus.INTERNAL_SERVER_ERROR))
+    except Exception as e:
+        message = 'Encountered exception: ' + str(e)
+        logging.exception(message)
+        return make_response(jsonify({'dashboards': [], 'msg': message}), HTTPStatus.INTERNAL_SERVER_ERROR)
+
+
+@action_logging
+def _get_related_dashboards_metadata(*, url: str) -> Dict[str, Any]:
+
+    results_dict: Dict[str, Any] = {
+        'dashboards': [],
+        'msg': '',
+    }
+
+    try:
+        response = request_metadata(url=url)
+    except ValueError as e:
+        # envoy client BadResponse is a subclass of ValueError
+        message = 'Encountered exception: ' + str(e)
+        results_dict['msg'] = message
+        results_dict['status_code'] = getattr(e, 'code', HTTPStatus.INTERNAL_SERVER_ERROR)
+        logging.exception(message)
+        return results_dict
+
+    status_code = response.status_code
+    results_dict['status_code'] = status_code
+
+    if status_code != HTTPStatus.OK:
+        message = f'Encountered {status_code} Error: Related dashboard metadata request failed'
+        results_dict['msg'] = message
+        logging.error(message)
+        return results_dict
+
+    try:
+        dashboard_data_raw = response.json().get('dashboards', [])
+        return {
+            'dashboards': [marshall_dashboard_partial(dashboard) for dashboard in dashboard_data_raw],
+            'msg': 'Success',
+            'status_code': status_code
+        }
+    except Exception as e:
+        message = 'Encountered exception: ' + str(e)
+        results_dict['msg'] = message
+        logging.exception(message)
+        # explicitly raise the exception which will trigger 500 api response
+        results_dict['status_code'] = getattr(e, 'code', HTTPStatus.INTERNAL_SERVER_ERROR)
+        return results_dict
+
+
+@metadata_blueprint.route('/get_table_lineage', methods=['GET'])
+def get_table_lineage() -> Response:
+    """
+    Call metadata service to fetch table lineage for a given table
+    :return:
+    """
+    try:
+        table_endpoint = _get_table_endpoint()
+        table_key = get_query_param(request.args, 'key')
+        depth = get_query_param(request.args, 'depth')
+        direction = get_query_param(request.args, 'direction')
+        url = f'{table_endpoint}/{table_key}/lineage?depth={depth}&direction={direction}'
+        response = request_metadata(url=url, method=request.method)
+        json = response.json()
+        downstream = [marshall_lineage_table(table) for table in json.get('downstream_entities')]
+        upstream = [marshall_lineage_table(table) for table in json.get('upstream_entities')]
+        downstream_count = json.get('downstream_count')
+        upstream_count = json.get('upstream_count')
+
+        payload = {
+            'downstream_entities': downstream,
+            'upstream_entities': upstream,
+            'downstream_count': downstream_count,
+
'upstream_count': upstream_count, + } + return make_response(jsonify(payload), 200) + except Exception as e: + payload = {'msg': 'Encountered exception: ' + str(e)} + return make_response(jsonify(payload), HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/get_column_lineage', methods=['GET']) +def get_column_lineage() -> Response: + """ + Call metadata service to fetch table lineage for a given table + :return: + """ + try: + table_endpoint = _get_table_endpoint() + table_key = get_query_param(request.args, 'key') + column_name = get_query_param(request.args, 'column_name') + url = f'{table_endpoint}/{table_key}/column/{column_name}/lineage' + response = request_metadata(url=url, method=request.method) + json = response.json() + downstream = [marshall_lineage_table(table) for table in json.get('downstream_entities')] + upstream = [marshall_lineage_table(table) for table in json.get('upstream_entities')] + downstream_count = json.get('downstream_count') + upstream_count = json.get('upstream_count') + + payload = { + 'downstream_entities': downstream, + 'upstream_entities': upstream, + 'downstream_count': downstream_count, + 'upstream_count': upstream_count, + } + return make_response(jsonify(payload), 200) + except Exception as e: + payload = {'msg': 'Encountered exception: ' + str(e)} + return make_response(jsonify(payload), HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/get_feature_description', methods=['GET']) +def get_feature_description() -> Response: + try: + feature_key = get_query_param(request.args, 'key') + + endpoint = _get_feature_endpoint() + + url = '{0}/{1}/description'.format(endpoint, feature_key) + + response = request_metadata(url=url) + status_code = response.status_code + + if status_code == HTTPStatus.OK: + message = 'Success' + description = response.json().get('description') + else: + message = 'Get feature description failed' + description = None + + payload = jsonify({'description': description, 'msg': message}) + return make_response(payload, status_code) + except Exception as e: + payload = jsonify({'description': None, 'msg': 'Encountered exception: ' + str(e)}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/put_feature_description', methods=['PUT']) +def put_feature_description() -> Response: + try: + args = request.get_json() + feature_key = get_query_param(args, 'key') + description = get_query_param(args, 'description') + + endpoint = _get_feature_endpoint() + + url = '{0}/{1}/description'.format(endpoint, feature_key) + + response = request_metadata(url=url, method='PUT', data=json.dumps({'description': description})) + status_code = response.status_code + + if status_code == HTTPStatus.OK: + message = 'Success' + else: + message = 'Update feature description failed' + + payload = jsonify({'msg': message}) + return make_response(payload, status_code) + except Exception as e: + payload = jsonify({'msg': 'Encountered exception: ' + str(e)}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/get_feature_generation_code', methods=['GET']) +def get_feature_generation_code() -> Response: + """ + Call metadata service to fetch feature generation code + :return: + """ + try: + feature_key = get_query_param(request.args, 'key') + + endpoint = _get_feature_endpoint() + + url = f'{endpoint}/{feature_key}/generation_code' + response = request_metadata(url=url, method=request.method) + payload = response.json() + return 
make_response(jsonify(payload), 200) + except Exception as e: + payload = jsonify({'msg': 'Encountered exception: ' + str(e)}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/get_feature_lineage', methods=['GET']) +def get_feature_lineage() -> Response: + """ + Call metadata service to fetch table lineage for a given feature + :return: + """ + try: + feature_key = get_query_param(request.args, 'key') + depth = get_query_param(request.args, 'depth') + direction = get_query_param(request.args, 'direction') + + endpoint = _get_feature_endpoint() + + url = f'{endpoint}/{feature_key}/lineage?depth={depth}&direction={direction}' + response = request_metadata(url=url, method=request.method) + json = response.json() + downstream = [marshall_lineage_table(table) for table in json.get('downstream_entities')] + upstream = [marshall_lineage_table(table) for table in json.get('upstream_entities')] + downstream_count = json.get('downstream_count') + upstream_count = json.get('upstream_count') + + payload = { + 'downstream_entities': downstream, + 'upstream_entities': upstream, + 'downstream_count': downstream_count, + 'upstream_count': upstream_count, + } + return make_response(jsonify(payload), 200) + except Exception as e: + payload = {'msg': 'Encountered exception: ' + str(e)} + return make_response(jsonify(payload), HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/update_feature_owner', methods=['PUT', 'DELETE']) +def update_feature_owner() -> Response: + try: + args = request.get_json() + feature_key = get_query_param(args, 'key') + owner = get_query_param(args, 'owner') + + endpoint = _get_feature_endpoint() + + url = '{0}/{1}/owner/{2}'.format(endpoint, feature_key, owner) + method = request.method + + response = request_metadata(url=url, method=method) + status_code = response.status_code + + if status_code == HTTPStatus.OK: + message = 'Updated owner' + else: + message = 'There was a problem updating owner {0}'.format(owner) + + payload = jsonify({'msg': message}) + return make_response(payload, status_code) + except Exception as e: + payload = jsonify({'msg': 'Encountered exception: ' + str(e)}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +def _update_metadata_feature_tag(endpoint: str, feature_key: str, method: str, tag: str) -> int: + url = f'{endpoint}/{feature_key}/tag/{tag}' + response = request_metadata(url=url, method=method) + status_code = response.status_code + if status_code != HTTPStatus.OK: + LOGGER.info(f'Fail to update tag in metadataservice, http status code: {status_code}') + LOGGER.debug(response.text) + return status_code + + +@metadata_blueprint.route('/update_feature_tags', methods=['PUT', 'DELETE']) +def update_feature_tags() -> Response: + try: + args = request.get_json() + method = request.method + feature_key = get_query_param(args, 'key') + tag = get_query_param(args, 'tag') + + endpoint = _get_feature_endpoint() + + metadata_status_code = _update_metadata_feature_tag(endpoint=endpoint, + feature_key=feature_key, + method=method, tag=tag) + + search_method = method if method == 'DELETE' else 'POST' + update_request = UpdateDocumentRequest(resource_key=feature_key, + resource_type=ResourceType.Feature.name.lower(), + field='tags', + value=tag, + operation='add') + request_json = json.dumps(UpdateDocumentRequestSchema().dump(update_request)) + + search_status_code = execute_search_document_request(request_json=request_json, + method=search_method) + http_status_code = 
HTTPStatus.OK + if metadata_status_code == HTTPStatus.OK and search_status_code == HTTPStatus.OK: + message = 'Success' + else: + message = f'Encountered error: {method} feature tag failed' + logging.error(message) + http_status_code = HTTPStatus.INTERNAL_SERVER_ERROR + + payload = jsonify({'msg': message}) + return make_response(payload, http_status_code) + + except Exception as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + payload = jsonify({'msg': message}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +@metadata_blueprint.route('/feature', methods=['GET']) +def get_feature_metadata() -> Response: + """ + call the metadata service endpoint and return matching results + :return: a json output containing a feature metadata object as 'featureData' + + """ + try: + feature_key = get_query_param(request.args, 'key') + list_item_index = request.args.get('index', None) + list_item_source = request.args.get('source', None) + + results_dict = _get_feature_metadata(feature_key=feature_key, index=list_item_index, source=list_item_source) + return make_response(jsonify(results_dict), results_dict.get('status_code', HTTPStatus.INTERNAL_SERVER_ERROR)) + except Exception as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + return make_response(jsonify({'featureData': {}, 'msg': message}), HTTPStatus.INTERNAL_SERVER_ERROR) + + +@action_logging +def _get_feature_metadata(*, feature_key: str, index: int, source: str) -> Dict[str, Any]: + + results_dict: Dict[str, Any] = { + 'featureData': {}, + 'msg': '', + } + + try: + feature_endpoint = _get_feature_endpoint() + url = f'{feature_endpoint}/{feature_key}' + response = request_metadata(url=url) + except ValueError as e: + # envoy client BadResponse is a subclass of ValueError + message = 'Encountered exception: ' + str(e) + results_dict['msg'] = message + results_dict['status_code'] = getattr(e, 'code', HTTPStatus.INTERNAL_SERVER_ERROR) + logging.exception(message) + return results_dict + + status_code = response.status_code + results_dict['status_code'] = status_code + + if status_code != HTTPStatus.OK: + message = 'Encountered error: Metadata request failed' + results_dict['msg'] = message + logging.error(message) + return results_dict + + try: + feature_data_raw: dict = response.json() + + feature_data_raw['key'] = feature_key + + results_dict['featureData'] = marshall_feature_full(feature_data_raw) + results_dict['msg'] = 'Success' + return results_dict + except Exception as e: + message = 'Encountered exception: ' + str(e) + results_dict['msg'] = message + logging.exception(message) + # explicitly raise the exception which will trigger 500 api response + results_dict['status_code'] = getattr(e, 'code', HTTPStatus.INTERNAL_SERVER_ERROR) + return results_dict diff --git a/frontend/amundsen_application/api/notice/__init__.py b/frontend/amundsen_application/api/notice/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/frontend/amundsen_application/api/notice/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/frontend/amundsen_application/api/notice/v0.py b/frontend/amundsen_application/api/notice/v0.py new file mode 100644 index 0000000000..5c8f045281 --- /dev/null +++ b/frontend/amundsen_application/api/notice/v0.py @@ -0,0 +1,65 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import json + +from http import HTTPStatus +from typing import cast + + +from flask import Response, jsonify, make_response, request, current_app as app +from flask.blueprints import Blueprint +from marshmallow import ValidationError +from werkzeug.utils import import_string + +from amundsen_application.api.utils.request_utils import get_query_param +from amundsen_application.base.base_notice_client import BaseNoticeClient + +LOGGER = logging.getLogger(__name__) +NOTICE_CLIENT_INSTANCE = None + +notices_blueprint = Blueprint('notices', __name__, url_prefix='/api/notices/v0') + + +def get_notice_client() -> BaseNoticeClient: + global NOTICE_CLIENT_INSTANCE + if NOTICE_CLIENT_INSTANCE is None and app.config['NOTICE_CLIENT'] is not None: + notice_client_class = import_string(app.config['NOTICE_CLIENT']) + NOTICE_CLIENT_INSTANCE = notice_client_class() + return cast(BaseNoticeClient, NOTICE_CLIENT_INSTANCE) + + +@notices_blueprint.route('/table', methods=['GET']) +def get_table_notices_summary() -> Response: + global NOTICE_CLIENT_INSTANCE + try: + client = get_notice_client() + if client is not None: + return _get_table_notices_summary_client() + payload = jsonify({'notices': {}, 'msg': 'A client for retrieving resource notices must be configured'}) + return make_response(payload, HTTPStatus.NOT_IMPLEMENTED) + except Exception as e: + message = 'Encountered exception: ' + str(e) + LOGGER.exception(message) + payload = jsonify({'notices': {}, 'msg': message}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +def _get_table_notices_summary_client() -> Response: + client = get_notice_client() + table_key = get_query_param(request.args, 'key') + response = client.get_table_notices_summary(table_key=table_key) + status_code = response.status_code + if status_code == HTTPStatus.OK: + try: + notices = json.loads(response.data).get('notices') + payload = jsonify({'notices': notices, 'msg': 'Success'}) + except ValidationError as err: + LOGGER.info('Notices data dump returned errors: ' + str(err.messages)) + raise Exception(f"Notices client didn't return a valid ResourceNotice object. {err}") + else: + message = f'Encountered error: Notice client request failed with code {status_code}' + LOGGER.error(message) + payload = jsonify({'notices': {}, 'msg': message}) + return make_response(payload, status_code) diff --git a/frontend/amundsen_application/api/preview/__init__.py b/frontend/amundsen_application/api/preview/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/frontend/amundsen_application/api/preview/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/frontend/amundsen_application/api/preview/dashboard/__init__.py b/frontend/amundsen_application/api/preview/dashboard/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/frontend/amundsen_application/api/preview/dashboard/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/frontend/amundsen_application/api/preview/dashboard/dashboard_preview/__init__.py b/frontend/amundsen_application/api/preview/dashboard/dashboard_preview/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/frontend/amundsen_application/api/preview/dashboard/dashboard_preview/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 diff --git a/frontend/amundsen_application/api/preview/dashboard/dashboard_preview/mode_preview.py b/frontend/amundsen_application/api/preview/dashboard/dashboard_preview/mode_preview.py new file mode 100644 index 0000000000..27b0abbed1 --- /dev/null +++ b/frontend/amundsen_application/api/preview/dashboard/dashboard_preview/mode_preview.py @@ -0,0 +1,119 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Optional, Any + +import requests +from flask import has_app_context, current_app as app +from requests.auth import HTTPBasicAuth +from retrying import retry + +from amundsen_application.api.metadata.v0 import USER_ENDPOINT +from amundsen_application.api.utils.request_utils import request_metadata +from amundsen_application.base.base_preview import BasePreview +from amundsen_application.models.user import load_user + +LOGGER = logging.getLogger(__name__) +DEFAULT_REPORT_URL_TEMPLATE = 'https://app.mode.com/api/{organization}/reports/{dashboard_id}' + + +def _validate_not_none(var: Any, var_name: str) -> Any: + if not var: + raise ValueError('{} is missing'.format(var_name)) + return var + + +def _retry_on_retriable_error(exception: Exception) -> bool: + return not isinstance(exception, PermissionError) + + +class ModePreview(BasePreview): + """ + A class to get Mode Dashboard preview image + """ + + def __init__(self, *, + access_token: Optional[str] = None, + password: Optional[str] = None, + organization: Optional[str] = None, + report_url_template: Optional[str] = None): + self._access_token = access_token if access_token else app.config['CREDENTIALS_MODE_ADMIN_TOKEN'] + _validate_not_none(self._access_token, 'access_token') + self._password = password if password else app.config['CREDENTIALS_MODE_ADMIN_PASSWORD'] + _validate_not_none(self._password, 'password') + self._organization = organization if organization else app.config['MODE_ORGANIZATION'] + _validate_not_none(self._organization, 'organization') + + self._report_url_template = report_url_template if report_url_template else DEFAULT_REPORT_URL_TEMPLATE + + if has_app_context() and app.config['MODE_REPORT_URL_TEMPLATE'] is not None: + self._report_url_template = app.config['MODE_REPORT_URL_TEMPLATE'] + + self._is_auth_enabled = False + if has_app_context() and app.config['ACL_ENABLED_DASHBOARD_PREVIEW']: + if not app.config['AUTH_USER_METHOD']: + raise Exception('AUTH_USER_METHOD must be configured to enable ACL_ENABLED_DASHBOARD_PREVIEW') + self._is_auth_enabled = self.__class__.__name__ in app.config['ACL_ENABLED_DASHBOARD_PREVIEW'] + self._auth_user_method = app.config['AUTH_USER_METHOD'] + + @retry(stop_max_attempt_number=3, wait_random_min=500, wait_random_max=1000, + retry_on_exception=_retry_on_retriable_error) + def get_preview_image(self, *, uri: str) -> bytes: + """ + Retrieves short lived URL that provides Mode report preview, downloads it and returns it's bytes + :param uri: + :return: image bytes + :raise: PermissionError when user is not allowed to access the dashboard + """ + if self._is_auth_enabled: + self._authorize_access(user_id=self._auth_user_method(app).user_id) + + url = self._get_preview_image_url(uri=uri) + r = requests.get(url, allow_redirects=True) + r.raise_for_status() + + return r.content + + def _get_preview_image_url(self, *, uri: str) -> str: + url = self._report_url_template.format(organization=self._organization, dashboard_id=uri.split('/')[-1]) + + LOGGER.info('Calling URL {} 
to fetch preview image URL'.format(url)) + response = requests.get(url, auth=HTTPBasicAuth(self._access_token, self._password)) + if response.status_code == 404: + raise FileNotFoundError('Dashboard {} not found. Possibly has been deleted.'.format(uri)) + + response.raise_for_status() + + web_preview_image_key = 'web_preview_image' + result = response.json() + + if web_preview_image_key not in result: + raise FileNotFoundError('No preview image available on {}'.format(uri)) + + image_url = result[web_preview_image_key] + if image_url is None: + raise FileNotFoundError('No preview image available on {}'.format(uri)) + + return image_url + + def _authorize_access(self, user_id: str) -> None: + """ + Get Mode user ID via metadata service. Note that metadata service needs to be at least v2.5.2 and + Databuilder should also have ingested Mode user. + https://github.com/lyft/amundsendatabuilder#modedashboarduserextractor + + :param user_id: + :return: + :raise: PermissionError when user is not allowed to access the dashboard + """ + + metadata_svc_url = '{0}{1}/{2}'.format(app.config['METADATASERVICE_BASE'], USER_ENDPOINT, user_id) + response = request_metadata(url=metadata_svc_url) + response.raise_for_status() + + user = load_user(response.json()) + if user.is_active and user.other_key_values and user.other_key_values.get('mode_user_id'): + return + + raise PermissionError('User {} is not authorized to preview Mode Dashboard'.format(user_id)) diff --git a/frontend/amundsen_application/api/preview/dashboard/dashboard_preview/preview_factory_method.py b/frontend/amundsen_application/api/preview/dashboard/dashboard_preview/preview_factory_method.py new file mode 100644 index 0000000000..a4fbec30d0 --- /dev/null +++ b/frontend/amundsen_application/api/preview/dashboard/dashboard_preview/preview_factory_method.py @@ -0,0 +1,43 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from abc import ABCMeta, abstractmethod + +from amundsen_application.base.base_preview import BasePreview +from amundsen_application.api.preview.dashboard.dashboard_preview.mode_preview import ModePreview + +LOGGER = logging.getLogger(__name__) + + +class BasePreviewMethodFactory(metaclass=ABCMeta): + + @abstractmethod + def get_instance(self, *, uri: str) -> BasePreview: + """ + Provides an instance of BasePreview based on uri + :param uri: + :return: + """ + pass + + +class DefaultPreviewMethodFactory(BasePreviewMethodFactory): + + def __init__(self) -> None: + # Register preview clients here. Key: product, Value: BasePreview implementation + self._object_map = { + 'mode': ModePreview() + } + LOGGER.info('Supported products: {}'.format(list(self._object_map.keys()))) + + def get_instance(self, *, uri: str) -> BasePreview: + product = self.get_product(uri=uri) + + if product in self._object_map: + return self._object_map[product] + + raise NotImplementedError('Product {} is not supported'.format(product)) + + def get_product(self, *, uri: str) -> str: + return uri.split('_')[0] diff --git a/frontend/amundsen_application/api/preview/dashboard/v0.py b/frontend/amundsen_application/api/preview/dashboard/v0.py new file mode 100644 index 0000000000..8ce057beef --- /dev/null +++ b/frontend/amundsen_application/api/preview/dashboard/v0.py @@ -0,0 +1,58 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0
+
+import io
+import logging
+from http import HTTPStatus
+
+from flask import send_file, jsonify, make_response, Response, current_app as app
+from flask.blueprints import Blueprint
+
+from amundsen_application.api.preview.dashboard.dashboard_preview.preview_factory_method import \
+    DefaultPreviewMethodFactory, BasePreviewMethodFactory
+
+LOGGER = logging.getLogger(__name__)
+PREVIEW_FACTORY: BasePreviewMethodFactory = None  # type: ignore
+
+dashboard_preview_blueprint = Blueprint('dashboard_preview', __name__, url_prefix='/api/dashboard_preview/v0')
+
+
+def initialize_preview_factory_class() -> None:
+    """
+    Instantiates Preview factory class and assign it to PREVIEW_FACTORY
+    :return: None
+    """
+    global PREVIEW_FACTORY
+
+    PREVIEW_FACTORY = app.config['DASHBOARD_PREVIEW_FACTORY']
+    if not PREVIEW_FACTORY:
+        PREVIEW_FACTORY = DefaultPreviewMethodFactory()
+
+    LOGGER.info('Using {} for Dashboard'.format(PREVIEW_FACTORY))
+
+
+@dashboard_preview_blueprint.route('/dashboard/<path:uri>/preview.jpg', methods=['GET'])
+def get_preview_image(uri: str) -> Response:
+    """
+    Provides preview image of Dashboard which can be cached for a day (by default).
+    :return:
+    """
+
+    if not PREVIEW_FACTORY:
+        LOGGER.info('Initializing Dashboard PREVIEW_FACTORY')
+        initialize_preview_factory_class()
+
+    preview_client = PREVIEW_FACTORY.get_instance(uri=uri)
+    try:
+        return send_file(io.BytesIO(preview_client.get_preview_image(uri=uri)),
+                         mimetype='image/jpeg',
+                         max_age=app.config['DASHBOARD_PREVIEW_IMAGE_CACHE_MAX_AGE_SECONDS'])
+    except FileNotFoundError as fne:
+        LOGGER.exception('FileNotFoundError on get_preview_image')
+        return make_response(jsonify({'msg': fne.args[0]}), HTTPStatus.NOT_FOUND)
+    except PermissionError as pe:
+        LOGGER.exception('PermissionError on get_preview_image')
+        return make_response(jsonify({'msg': pe.args[0]}), HTTPStatus.UNAUTHORIZED)
+    except Exception as e:
+        LOGGER.exception('Unexpected failure on get_preview_image')
+        return make_response(jsonify({'msg': 'Encountered exception: ' + str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR)
diff --git a/frontend/amundsen_application/api/preview/v0.py b/frontend/amundsen_application/api/preview/v0.py
new file mode 100644
index 0000000000..5ee9d3ad97
--- /dev/null
+++ b/frontend/amundsen_application/api/preview/v0.py
@@ -0,0 +1,112 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0 + +import json +import logging +from pkg_resources import iter_entry_points + +from http import HTTPStatus + +from flask import Response, jsonify, make_response, request, current_app as app +from flask.blueprints import Blueprint +from marshmallow import ValidationError +from werkzeug.utils import import_string + +from amundsen_application.models.preview_data import PreviewDataSchema + +LOGGER = logging.getLogger(__name__) +PREVIEW_CLIENT_CLASS = None +PREVIEW_CLIENT_INSTANCE = None + +for entry_point in iter_entry_points(group='preview_client', name='table_preview_client_class'): + preview_client_class = entry_point.load() + if preview_client_class is not None: + PREVIEW_CLIENT_CLASS = preview_client_class + +preview_blueprint = Blueprint('preview', __name__, url_prefix='/api/preview/v0') + + +@preview_blueprint.route('/', methods=['POST']) +def get_table_preview() -> Response: + global PREVIEW_CLIENT_INSTANCE + global PREVIEW_CLIENT_CLASS + try: + if PREVIEW_CLIENT_INSTANCE is None: + if PREVIEW_CLIENT_CLASS is not None: + PREVIEW_CLIENT_INSTANCE = PREVIEW_CLIENT_CLASS() + logging.warn('Setting preview_client via entry_point is DEPRECATED and ' + 'will be removed in a future version') + elif (app.config['PREVIEW_CLIENT_ENABLED'] + and app.config['PREVIEW_CLIENT'] is not None): + PREVIEW_CLIENT_CLASS = import_string(app.config['PREVIEW_CLIENT']) + PREVIEW_CLIENT_INSTANCE = PREVIEW_CLIENT_CLASS() + else: + payload = jsonify({'previewData': {}, 'msg': 'A client for the preview feature must be configured'}) + return make_response(payload, HTTPStatus.NOT_IMPLEMENTED) + + response = PREVIEW_CLIENT_INSTANCE.get_preview_data(params=request.get_json()) + status_code = response.status_code + + preview_data = json.loads(response.data).get('preview_data') + if status_code == HTTPStatus.OK: + # validate the returned table preview data + try: + data = PreviewDataSchema().load(preview_data) + payload = jsonify({'previewData': data, 'msg': 'Success'}) + except ValidationError as err: + logging.error('Preview data dump returned errors: ' + str(err.messages)) + raise Exception('The preview client did not return a valid PreviewData object') + else: + message = 'Encountered error: Preview client request failed with code ' + str(status_code) + logging.error(message) + # only necessary to pass the error text + payload = jsonify({'previewData': {'error_text': preview_data.get('error_text', '')}, 'msg': message}) + return make_response(payload, status_code) + except Exception as e: + message = f'Encountered exception: {str(e)}' + logging.exception(message) + payload = jsonify({'previewData': {}, 'msg': message}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +@preview_blueprint.route('/feature_preview', methods=['POST']) +def get_feature_preview() -> Response: + global PREVIEW_CLIENT_INSTANCE + global PREVIEW_CLIENT_CLASS + try: + if PREVIEW_CLIENT_INSTANCE is None: + if PREVIEW_CLIENT_CLASS is not None: + PREVIEW_CLIENT_INSTANCE = PREVIEW_CLIENT_CLASS() + logging.warn('Setting preview_client via entry_point is DEPRECATED and ' + 'will be removed in a future version') + elif (app.config['PREVIEW_CLIENT_ENABLED'] + and app.config['PREVIEW_CLIENT'] is not None): + PREVIEW_CLIENT_CLASS = import_string(app.config['PREVIEW_CLIENT']) + PREVIEW_CLIENT_INSTANCE = PREVIEW_CLIENT_CLASS() + else: + payload = jsonify({'previewData': {}, 'msg': 'A client for the preview feature must be configured'}) + return make_response(payload, HTTPStatus.NOT_IMPLEMENTED) + + 
response = PREVIEW_CLIENT_INSTANCE.get_feature_preview_data(params=request.get_json()) + status_code = response.status_code + + preview_data = json.loads(response.data).get('preview_data') + if status_code == HTTPStatus.OK: + # validate the returned feature preview data + try: + data = PreviewDataSchema().load(preview_data) + payload = jsonify({'previewData': data, 'msg': 'Success'}) + except ValidationError as err: + logging.error('Preview data dump returned errors: ' + str(err.messages)) + raise Exception('The preview client did not return a valid PreviewData object') + else: + message = 'Encountered error: Preview client request failed with code ' + str(status_code) + logging.error(message) + # only necessary to pass the error text + payload = jsonify({'previewData': {'error_text': preview_data.get('error_text', '')}, 'msg': message}) + return make_response(payload, status_code) + except Exception as e: + message = f'Encountered exception: {str(e)}' + logging.exception(message) + payload = jsonify({'previewData': {}, 'msg': message}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/frontend/amundsen_application/api/quality/__init__.py b/frontend/amundsen_application/api/quality/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/frontend/amundsen_application/api/quality/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/frontend/amundsen_application/api/quality/v0.py b/frontend/amundsen_application/api/quality/v0.py new file mode 100644 index 0000000000..6376756387 --- /dev/null +++ b/frontend/amundsen_application/api/quality/v0.py @@ -0,0 +1,65 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import json +import logging + +from http import HTTPStatus +from typing import cast + +from flask import Response, jsonify, make_response, request, current_app as app +from flask.blueprints import Blueprint +from marshmallow import ValidationError +from werkzeug.utils import import_string + +from amundsen_application.api.utils.request_utils import get_query_param +from amundsen_application.base.base_quality_client import BaseQualityClient + +LOGGER = logging.getLogger(__name__) +QUALITY_CLIENT_INSTANCE = None + +quality_blueprint = Blueprint('quality', __name__, url_prefix='/api/quality/v0') + + +def get_quality_client() -> BaseQualityClient: + global QUALITY_CLIENT_INSTANCE + if QUALITY_CLIENT_INSTANCE is None and app.config['QUALITY_CLIENT'] is not None: + quality_client_class = import_string(app.config['QUALITY_CLIENT']) + QUALITY_CLIENT_INSTANCE = quality_client_class() + return cast(BaseQualityClient, QUALITY_CLIENT_INSTANCE) + + +@quality_blueprint.route('/table/summary', methods=['GET']) +def get_table_quality_checks_summary() -> Response: + global QUALITY_CLIENT_INSTANCE + try: + client = get_quality_client() + if client is not None: + return _get_dq_checks_summary_client() + payload = jsonify({'checks': {}, 'msg': 'A client for retrieving quality checks must be configured'}) + return make_response(payload, HTTPStatus.NOT_IMPLEMENTED) + except Exception as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + payload = jsonify({'checks': {}, 'msg': message}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + +def _get_dq_checks_summary_client() -> Response: + client = get_quality_client() + table_key = get_query_param(request.args, 'key') + response = 
client.get_table_quality_checks_summary(table_key=table_key) + status_code = response.status_code + if status_code == HTTPStatus.OK: + try: + quality_checks = json.loads(response.data).get('checks') + payload = jsonify({'checks': quality_checks, 'msg': 'Success'}) + except ValidationError as err: + logging.error('Quality data dump returned errors: ' + str(err.messages)) + raise Exception('The preview client did not return a valid Quality Checks object') + else: + message = 'Encountered error: Quality client request failed with code ' + str(status_code) + logging.error(message) + # only necessary to pass the error text + payload = jsonify({'checks': {}, 'msg': message}) + return make_response(payload, status_code) diff --git a/frontend/amundsen_application/api/search/__init__.py b/frontend/amundsen_application/api/search/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/frontend/amundsen_application/api/search/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/frontend/amundsen_application/api/search/v1.py b/frontend/amundsen_application/api/search/v1.py new file mode 100644 index 0000000000..a53ddb69a8 --- /dev/null +++ b/frontend/amundsen_application/api/search/v1.py @@ -0,0 +1,159 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import json +import logging +from http import HTTPStatus +from typing import Any, Dict, List # noqa: F401 + +from amundsen_common.models.search import (Filter, SearchRequestSchema, + SearchResponseSchema) +from flask import Response +from flask import current_app as app +from flask import jsonify, make_response, request +from flask.blueprints import Blueprint + +from amundsen_application.log.action_log import action_logging +from amundsen_application.api.utils.request_utils import (get_query_param, + request_search) +from amundsen_application.api.utils.search_utils import ( + generate_query_request, map_dashboard_result, map_feature_result, + map_table_result, map_user_result) + +LOGGER = logging.getLogger(__name__) + +REQUEST_SESSION_TIMEOUT_SEC = 3 + +SEARCH_ENDPOINT = '/v2/search' + +RESOURCE_TO_MAPPING = { + 'table': map_table_result, + 'dashboard': map_dashboard_result, + 'feature': map_feature_result, + 'user': map_user_result, +} + +DEFAULT_FILTER_OPERATION = 'OR' + +search_blueprint = Blueprint('search', __name__, url_prefix='/api/search/v1') + + +def _transform_filters(filters: Dict, resources: List[str]) -> List[Filter]: + transformed_filters = [] + searched_resources_with_filters = set(filters.keys()).intersection(resources) + for resource in searched_resources_with_filters: + resource_filters = filters[resource] + for field in resource_filters.keys(): + field_filters = resource_filters[field] + values = [] + filter_operation = DEFAULT_FILTER_OPERATION + + if field_filters is not None and field_filters.get('value') is not None: + value_str = field_filters.get('value') + values = [str.strip() for str in value_str.split(',') if str != ''] + filter_operation = field_filters.get('filterOperation', DEFAULT_FILTER_OPERATION) + + transformed_filters.append(Filter(name=field, + values=values, + operation=filter_operation)) + + return transformed_filters + + +@search_blueprint.route('/search', methods=['POST']) +def search() -> Response: + """ + Parse the request arguments and call the helper method to execute a search for specified resources + :return: a Response created with the results from the helper 
method + """ + results_dict = {} + try: + request_json = request.get_json() + search_term = get_query_param(request_json, 'searchTerm', '"searchTerm" parameter expected in request data') + page_index = get_query_param(request_json, 'pageIndex', '"pageIndex" parameter expected in request data') + results_per_page = get_query_param(request_json, + 'resultsPerPage', + '"resultsPerPage" parameter expected in request data') + search_type = request_json.get('searchType') + resources = request_json.get('resources', []) + filters = request_json.get('filters', {}) + highlight_options = request_json.get('highlightingOptions', {}) + results_dict = _search_resources(search_term=search_term, + resources=resources, + page_index=int(page_index), + results_per_page=int(results_per_page), + filters=filters, + highlight_options=highlight_options, + search_type=search_type) + return make_response(jsonify(results_dict), results_dict.get('status_code', HTTPStatus.OK)) + except Exception as e: + message = 'Encountered exception: ' + str(e) + LOGGER.exception(message) + return make_response(jsonify(results_dict), HTTPStatus.INTERNAL_SERVER_ERROR) + + +@action_logging +def _search_resources(*, search_term: str, + resources: List[str], + page_index: int, + results_per_page: int, + filters: Dict, + highlight_options: Dict, + search_type: str) -> Dict[str, Any]: + """ + Call the search service endpoint and return matching results + :return: a json output containing search results array as 'results' + """ + default_results = { + 'page_index': int(page_index), + 'results': [], + 'total_results': 0, + } + + results_dict = { + 'search_term': search_term, + 'msg': '', + 'table': default_results, + 'dashboard': default_results, + 'feature': default_results, + 'user': default_results, + } + + try: + transformed_filters = _transform_filters(filters=filters, resources=resources) + query_request = generate_query_request(filters=transformed_filters, + resources=resources, + page_index=page_index, + results_per_page=results_per_page, + search_term=search_term, + highlight_options=highlight_options) + request_json = json.dumps(SearchRequestSchema().dump(query_request)) + url_base = app.config['SEARCHSERVICE_BASE'] + SEARCH_ENDPOINT + response = request_search(url=url_base, + headers={'Content-Type': 'application/json'}, + method='POST', + data=request_json) + status_code = response.status_code + + if status_code == HTTPStatus.OK: + search_response = SearchResponseSchema().loads(json.dumps(response.json())) + results_dict['msg'] = search_response.msg + results = search_response.results + for resource in results.keys(): + results_dict[resource] = { + 'page_index': int(page_index), + 'results': [RESOURCE_TO_MAPPING[resource](result) for result in results[resource]['results']], + 'total_results': results[resource]['total_results'], + } + else: + message = 'Encountered error: Search request failed' + results_dict['msg'] = message + + results_dict['status_code'] = status_code + return results_dict + + except Exception as e: + message = f'Encountered exception: {str(e)}' + results_dict['msg'] = message + LOGGER.exception(message) + return results_dict diff --git a/frontend/amundsen_application/api/utils/__init__.py b/frontend/amundsen_application/api/utils/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/frontend/amundsen_application/api/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 diff --git a/frontend/amundsen_application/api/utils/metadata_utils.py b/frontend/amundsen_application/api/utils/metadata_utils.py new file mode 100644 index 0000000000..0ab75db5a1 --- /dev/null +++ b/frontend/amundsen_application/api/utils/metadata_utils.py @@ -0,0 +1,293 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging + +from dataclasses import dataclass +from marshmallow import EXCLUDE +from typing import Any, Dict, List, Optional + +from amundsen_common.models.dashboard import DashboardSummary, DashboardSummarySchema +from amundsen_common.models.feature import Feature, FeatureSchema +from amundsen_common.models.popular_table import PopularTable, PopularTableSchema +from amundsen_common.models.table import Table, TableSchema, TypeMetadata +from amundsen_application.models.user import load_user, dump_user +from amundsen_application.config import MatchRuleObject +from flask import current_app as app +import re + + +@dataclass +class TableUri: + database: str + cluster: str + schema: str + table: str + + def __str__(self) -> str: + return f"{self.database}://{self.cluster}.{self.schema}/{self.table}" + + @classmethod + def from_uri(cls, uri: str) -> 'TableUri': + """ + TABLE_KEY_FORMAT = '{db}://{cluster}.{schema}/{tbl}' + """ + pattern = re.compile(r'^(?P.*?)://(?P.*)\.(?P.*?)/(?P.*?)$', re.X) + + groups = pattern.match(uri) + + spec = groups.groupdict() if groups else {} + + return TableUri(**spec) + + +def marshall_table_partial(table_dict: Dict) -> Dict: + """ + Forms a short version of a table Dict, with selected fields and an added 'key' + :param table_dict: Dict of partial table object + :return: partial table Dict + + TODO - Unify data format returned by search and metadata. + """ + schema = PopularTableSchema() + table: PopularTable = schema.load(table_dict, unknown=EXCLUDE) + results = schema.dump(table) + # TODO: fix popular tables to provide these? remove if we're not using them? 
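As a quick illustration of the table key format that TableUri parses and that marshall_table_partial reconstructs, a minimal sketch using a made-up table key:

from amundsen_application.api.utils.metadata_utils import TableUri

uri = TableUri.from_uri('hive://gold.core/fact_rides')  # hypothetical table key
assert (uri.database, uri.cluster, uri.schema, uri.table) == ('hive', 'gold', 'core', 'fact_rides')
assert str(uri) == 'hive://gold.core/fact_rides'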
+ # TODO: Add the 'key' or 'id' to the base PopularTableSchema + results['key'] = f'{table.database}://{table.cluster}.{table.schema}/{table.name}' + results['last_updated_timestamp'] = None + results['type'] = 'table' + + return results + + +def _parse_editable_rule(rule: MatchRuleObject, + schema: str, + table: str) -> bool: + """ + Matches table name and schema with corresponding regex in matching rule + :parm rule: MatchRuleObject defined in list UNEDITABLE_TABLE_DESCRIPTION_MATCH_RULES in config file + :parm schema: schema name from Table Dict received from metadata service + :parm table: table name from Table Dict received from metadata service + :return: boolean which determines if table desc is editable or not for given table as per input matching rule + """ + if rule.schema_regex and rule.table_name_regex: + match_schema = re.match(rule.schema_regex, schema) + match_table = re.match(rule.table_name_regex, table) + return not (match_schema and match_table) + + if rule.schema_regex: + return not re.match(rule.schema_regex, schema) + + if rule.table_name_regex: + return not re.match(rule.table_name_regex, table) + + return True + + +def is_table_editable(schema_name: str, table_name: str, cfg: Any = None) -> bool: + if cfg is None: + cfg = app.config + + if cfg['ALL_UNEDITABLE_SCHEMAS']: + return False + + if schema_name in cfg['UNEDITABLE_SCHEMAS']: + return False + + for rule in cfg['UNEDITABLE_TABLE_DESCRIPTION_MATCH_RULES']: + if not _parse_editable_rule(rule, schema_name, table_name): + return False + + return True + + +def _recursive_set_type_metadata_is_editable(type_metadata: Optional[TypeMetadata], is_editable: bool) -> None: + if type_metadata is not None: + type_metadata['is_editable'] = is_editable + for tm in getattr(type_metadata, 'children', []): + _recursive_set_type_metadata_is_editable(tm, is_editable) + + +def marshall_table_full(table_dict: Dict) -> Dict: + """ + Forms the full version of a table Dict, with additional and sanitized fields + :param table_dict: Table Dict from metadata service + :return: Table Dict with sanitized fields + """ + + schema = TableSchema() + table: Table = schema.load(table_dict) + results: Dict[str, Any] = schema.dump(table) + + is_editable = is_table_editable(results['schema'], results['name']) + results['is_editable'] = is_editable + + # TODO - Cleanup https://github.com/lyft/amundsen/issues/296 + # This code will try to supplement some missing data since the data here is incomplete. + # Once the metadata service response provides complete user objects we can remove this. 
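For reference, a configuration sketch showing how the editability settings checked above might be wired up; the schema names and regexes are examples, and the MatchRuleObject keyword arguments are assumed to mirror the schema_regex / table_name_regex attributes read by _parse_editable_rule:

from amundsen_application.config import MatchRuleObject

ALL_UNEDITABLE_SCHEMAS = False
UNEDITABLE_SCHEMAS = {'audit'}                      # hypothetical schema name
UNEDITABLE_TABLE_DESCRIPTION_MATCH_RULES = [
    MatchRuleObject(schema_regex=r'^raw_.*'),                                 # lock a whole schema family
    MatchRuleObject(schema_regex=r'^etl$', table_name_regex=r'.*_staging$'),  # schema plus table pattern
]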
+ results['owners'] = [_map_user_object_to_schema(owner) for owner in results['owners']] + readers = results['table_readers'] + for reader_object in readers: + reader_object['user'] = _map_user_object_to_schema(reader_object['user']) + + # TODO: Add the 'key' or 'id' to the base TableSchema + results['key'] = f'{table.database}://{table.cluster}.{table.schema}/{table.name}' + # Temp code to make 'partition_key' and 'partition_value' part of the table + results['partition'] = _get_partition_data(results['watermarks']) + + # We follow same style as column stat order for arranging the programmatic descriptions + prog_descriptions = results['programmatic_descriptions'] + results['programmatic_descriptions'] = _convert_prog_descriptions(prog_descriptions) + + columns = results['columns'] + for col in columns: + # Set column key to guarantee it is available on the frontend + # since it is currently an optional field in the model + col['key'] = results['key'] + '/' + col['name'] + # Set editable state + col['is_editable'] = is_editable + _recursive_set_type_metadata_is_editable(col['type_metadata'], is_editable) + # If order is provided, we sort the column based on the pre-defined order + if app.config['COLUMN_STAT_ORDER']: + # the stat_type isn't defined in COLUMN_STAT_ORDER, we just use the max index for sorting + col['stats'].sort(key=lambda x: app.config['COLUMN_STAT_ORDER']. + get(x['stat_type'], len(app.config['COLUMN_STAT_ORDER']))) + + return results + + +def marshall_dashboard_partial(dashboard_dict: Dict) -> Dict: + """ + Forms a short version of dashboard metadata, with selected fields and an added 'key' + and 'type' + :param dashboard_dict: Dict of partial dashboard metadata + :return: partial dashboard Dict + """ + schema = DashboardSummarySchema(unknown=EXCLUDE) + dashboard: DashboardSummary = schema.load(dashboard_dict) + results = schema.dump(dashboard) + results['type'] = 'dashboard' + # TODO: Bookmark logic relies on key, opting to add this here to avoid messy logic in + # React app and we have to clean up later. + results['key'] = results.get('uri', '') + return results + + +def marshall_dashboard_full(dashboard_dict: Dict) -> Dict: + """ + Cleanup some fields in the dashboard response + :param dashboard_dict: Dashboard response from metadata service. + :return: Dashboard dictionary with sanitized fields, particularly the tables and owners. + """ + # TODO - Cleanup https://github.com/lyft/amundsen/issues/296 + # This code will try to supplement some missing data since the data here is incomplete. + # Once the metadata service response provides complete user objects we can remove this. + dashboard_dict['owners'] = [_map_user_object_to_schema(owner) for owner in dashboard_dict['owners']] + dashboard_dict['tables'] = [marshall_table_partial(table) for table in dashboard_dict['tables']] + return dashboard_dict + + +def marshall_lineage_table(table_dict: Dict) -> Dict: + """ + Decorate lineage entries with database, schema, cluster, and table + :param table_dict: + :return: table entry with additional fields + """ + table_key = str(table_dict.get('key')) + table_uri = TableUri.from_uri(table_key) + table_dict['database'] = table_uri.database + table_dict['schema'] = table_uri.schema + table_dict['cluster'] = table_uri.cluster + table_dict['name'] = table_uri.table + return table_dict + + +def _convert_prog_descriptions(prog_descriptions: Optional[List] = None) -> Dict: + """ + Apply the PROGRAMMATIC_DISPLAY configuration to convert to the structure. 
+ :param prog_descriptions: A list of objects representing programmatic descriptions + :return: A dictionary with organized programmatic_descriptions + """ + left = [] # type: List + right = [] # type: List + other = prog_descriptions or [] # type: List + updated_descriptions = {} + + if prog_descriptions: + # We want to make sure there is a display title that is just source + for desc in prog_descriptions: + source = desc.get('source') + if not source: + logging.warning("no source found in: " + str(desc)) + + # If config is defined for programmatic disply we organize and sort them based on the configuration + prog_display_config = app.config['PROGRAMMATIC_DISPLAY'] + if prog_display_config: + left_config = prog_display_config.get('LEFT', {}) + left = [x for x in prog_descriptions if x.get('source') in left_config] + left.sort(key=lambda x: _sort_prog_descriptions(left_config, x)) + + right_config = prog_display_config.get('RIGHT', {}) + right = [x for x in prog_descriptions if x.get('source') in right_config] + right.sort(key=lambda x: _sort_prog_descriptions(right_config, x)) + + other_config = dict(filter(lambda x: x not in ['LEFT', 'RIGHT'], prog_display_config.items())) + other = list(filter(lambda x: x.get('source') not in left_config and x.get('source') + not in right_config, prog_descriptions)) + other.sort(key=lambda x: _sort_prog_descriptions(other_config, x)) + + updated_descriptions['left'] = left + updated_descriptions['right'] = right + updated_descriptions['other'] = other + return updated_descriptions + + +def _sort_prog_descriptions(base_config: Dict, prog_description: Dict) -> int: + default_order = len(base_config) + prog_description_source = prog_description.get('source') + config_dict = base_config.get(prog_description_source) + if config_dict: + return config_dict.get('display_order', default_order) + return default_order + + +def _map_user_object_to_schema(u: Dict) -> Dict: + return dump_user(load_user(u)) + + +def _get_partition_data(watermarks: Dict) -> Dict: + if watermarks: + high_watermark = next(filter(lambda x: x['watermark_type'] == 'high_watermark', watermarks)) + if high_watermark: + return { + 'is_partitioned': True, + 'key': high_watermark['partition_key'], + 'value': high_watermark['partition_value'] + } + return { + 'is_partitioned': False + } + + +def marshall_feature_full(feature_dict: Dict) -> Dict: + """ + Forms the full version of a table Dict, with additional and sanitized fields + :param table_dict: Table Dict from metadata service + :return: Table Dict with sanitized fields + """ + + schema = FeatureSchema() + feature: Feature = schema.load(feature_dict) + results: Dict[str, Any] = schema.dump(feature) + + # TODO do we need this for Features? + # is_editable = is_table_editable(results['schema'], results['name']) + # results['is_editable'] = is_editable + + results['owners'] = [_map_user_object_to_schema(owner) for owner in results['owners']] + + prog_descriptions = results['programmatic_descriptions'] + results['programmatic_descriptions'] = _convert_prog_descriptions(prog_descriptions) + + return results diff --git a/frontend/amundsen_application/api/utils/notification_utils.py b/frontend/amundsen_application/api/utils/notification_utils.py new file mode 100644 index 0000000000..877f97a805 --- /dev/null +++ b/frontend/amundsen_application/api/utils/notification_utils.py @@ -0,0 +1,235 @@ +# Copyright Contributors to the Amundsen project. 
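A sketch of the PROGRAMMATIC_DISPLAY shape that _convert_prog_descriptions and _sort_prog_descriptions consume; the source names are illustrative:

PROGRAMMATIC_DISPLAY = {
    'LEFT': {
        'airflow': {'display_order': 0},          # hypothetical source name
    },
    'RIGHT': {
        'quality_report': {'display_order': 0},   # hypothetical source name
    },
    # Any other top-level key configures ordering for the 'other' bucket.
    'data_steward': {'display_order': 1},
}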
+# SPDX-License-Identifier: Apache-2.0 + +import logging + +from http import HTTPStatus +from enum import Enum + +from flask import current_app as app +from flask import jsonify, make_response, Response +from typing import Dict, List + +from amundsen_application.api.exceptions import MailClientNotImplemented +from amundsen_application.log.action_log import action_logging + + +class NotificationType(str, Enum): + """ + Enum to describe supported notification types. Must match NotificationType interface defined in: + https://github.com/lyft/amundsenfrontendlibrary/blob/master/amundsen_application/static/js/interfaces/Notifications.ts + """ + OWNER_ADDED = 'owner_added' + OWNER_REMOVED = 'owner_removed' + METADATA_EDITED = 'metadata_edited' + METADATA_REQUESTED = 'metadata_requested' + DATA_ISSUE_REPORTED = 'data_issue_reported' + + @classmethod + def has_value(cls, value: str) -> bool: + for key in cls: + if key.value == value: + return True + return False + + +NOTIFICATION_STRINGS = { + NotificationType.OWNER_ADDED.value: { + 'comment': ('
<br/>What is expected of you?<br/>As an owner, you take an important part in making '
+                    'sure that the datasets you own can be used as swiftly as possible across the company.<br/>'
+                    'Make sure the metadata is correct and up to date.<br/>'),
+        'end_note': ('<br/>If you think you are not the best person to own this dataset and know someone who might '
+                     'be, please contact this person and ask them if they want to replace you. It is important that we '
+                     'keep multiple owners for each dataset to ensure continuity.<br/>'),
+        'notification': ('<br/>You have been added to the owners list of the '
+                         '<a href="{resource_url}">{resource_name}</a> dataset by {sender}.<br/>'),
+    },
+    NotificationType.OWNER_REMOVED.value: {
+        'comment': '',
+        'end_note': ('<br/>If you think you have been incorrectly removed as an owner, '
+                     'add yourself back to the owners list.<br/>'),
+        'notification': ('<br/>You have been removed from the owners list of the '
+                         '<a href="{resource_url}">{resource_name}</a> dataset by {sender}.<br/>'),
+    },
+    NotificationType.METADATA_REQUESTED.value: {
+        'comment': '',
+        'end_note': '<br/>Please visit the provided link and improve descriptions on that resource.<br/>',
+        'notification': '<br/>{sender} is trying to use <a href="{resource_url}">{resource_name}</a>, ',
+    },
+    NotificationType.DATA_ISSUE_REPORTED.value: {
+        'comment': '<br/>Link to the issue: <a href="{data_issue_url}">{data_issue_url}</a><br/>',
+        'end_note': '<br/>Please visit the provided issue link for more information. You are getting this email '
+                    'because you are listed as an owner of the resource. Please do not reply to this email.<br/>',
+        'notification': '<br/>
{sender} has reported a data issue for {resource_name}, ', + } +} + + +def get_mail_client(): # type: ignore + """ + Gets a mail_client object to send emails, raises an exception + if mail client isn't implemented + """ + mail_client = app.config['MAIL_CLIENT'] + + if not mail_client: + raise MailClientNotImplemented('An instance of BaseMailClient client must be configured on MAIL_CLIENT') + + return mail_client + + +def validate_options(*, options: Dict) -> None: + """ + Raises an Exception if the options do not contain resource_path or resource_name + """ + if options.get('resource_path') is None: + raise Exception('resource_path was not provided in the notification options') + if options.get('resource_name') is None: + raise Exception('resource_name was not provided in the notification options') + + +def get_notification_html(*, notification_type: str, options: Dict, sender: str) -> str: + """ + Returns the formatted html for the notification based on the notification_type + :return: A string representing the html markup to send in the notification + """ + validate_options(options=options) + + url_base = app.config['FRONTEND_BASE'] + resource_url = '{url_base}{resource_path}?source=notification'.format(resource_path=options.get('resource_path'), + url_base=url_base) + joined_chars = resource_url[len(url_base) - 1:len(url_base) + 1] + if joined_chars.count('/') != 1: + raise Exception('Configured "FRONTEND_BASE" and "resource_path" do not form a valid url') + + notification_strings = NOTIFICATION_STRINGS.get(notification_type) + if notification_strings is None: + raise Exception('Unsupported notification_type') + + greeting = 'Hello,
<br/>'
+    notification = notification_strings.get('notification', '').format(resource_url=resource_url,
+                                                                       resource_name=options.get('resource_name'),
+                                                                       sender=sender)
+    comment = notification_strings.get('comment', '')
+    end_note = notification_strings.get('end_note', '')
+    salutation = '<br/>Thanks,<br/>Amundsen Team'
+
+    if notification_type == NotificationType.METADATA_REQUESTED:
+        options_comment = options.get('comment')
+        need_resource_description = options.get('description_requested')
+        need_fields_descriptions = options.get('fields_requested')
+
+        if need_resource_description and need_fields_descriptions:
+            notification = notification + 'and requests improved table and column descriptions.<br/>'
+        elif need_resource_description:
+            notification = notification + 'and requests an improved table description.<br/>'
+        elif need_fields_descriptions:
+            notification = notification + 'and requests improved column descriptions.<br/>'
+        else:
+            notification = notification + 'and requests more information about that resource.<br/>'
+
+        if options_comment:
+            comment = ('<br/>{sender} has included the following information with their request:'
+                       '<br/>{comment}<br/>').format(sender=sender, comment=options_comment)
+
+    if notification_type == NotificationType.DATA_ISSUE_REPORTED:
+        greeting = 'Hello data owner,<br/>
' + data_issue_url = options.get('data_issue_url') + comment = comment.format(data_issue_url=data_issue_url) + + return '{greeting}{notification}{comment}{end_note}{salutation}'.format(greeting=greeting, + notification=notification, + comment=comment, + end_note=end_note, + salutation=salutation) + + +def get_notification_subject(*, notification_type: str, options: Dict) -> str: + """ + Returns the subject to use for the given notification_type + :param notification_type: type of notification + :param options: data necessary to render email template content + :return: The subject to be used with the notification + """ + resource_name = options.get('resource_name') + notification_subject_dict = { + NotificationType.OWNER_ADDED.value: 'You are now an owner of {}'.format(resource_name), + NotificationType.OWNER_REMOVED.value: 'You have been removed as an owner of {}'.format(resource_name), + NotificationType.METADATA_EDITED.value: 'Your dataset {}\'s metadata has been edited'.format(resource_name), + NotificationType.METADATA_REQUESTED.value: 'Request for metadata on {}'.format(resource_name), + NotificationType.DATA_ISSUE_REPORTED.value: 'A data issue has been reported for {}'.format(resource_name) + } + subject = notification_subject_dict.get(notification_type) + if subject is None: + raise Exception('Unsupported notification_type') + return subject + + +def send_notification(*, notification_type: str, options: Dict, recipients: List, sender: str) -> Response: + """ + Sends a notification via email to a given list of recipients + :param notification_type: type of notification + :param options: data necessary to render email template content + :param recipients: list of recipients who should receive notification + :param sender: email of notification sender + :return: Response + """ + @action_logging + def _log_send_notification(*, notification_type: str, options: Dict, recipients: List, sender: str) -> None: + """ Logs the content of a sent notification""" + pass # pragma: no cover + + try: + if not app.config['NOTIFICATIONS_ENABLED']: + message = 'Notifications are not enabled. Request was accepted but no notification will be sent.' + logging.exception(message) + return make_response(jsonify({'msg': message}), HTTPStatus.ACCEPTED) + if sender in recipients: + recipients.remove(sender) + if len(recipients) == 0: + logging.info('No recipients exist for notification') + return make_response( + jsonify({ + 'msg': 'No valid recipients exist for notification, notification was not sent.' 
+ }), + HTTPStatus.OK + ) + + mail_client = get_mail_client() + + html = get_notification_html(notification_type=notification_type, options=options, sender=sender) + subject = get_notification_subject(notification_type=notification_type, options=options) + + _log_send_notification( + notification_type=notification_type, + options=options, + recipients=recipients, + sender=sender + ) + + response = mail_client.send_email( + html=html, + subject=subject, + optional_data={ + 'email_type': notification_type, + }, + recipients=recipients, + sender=sender, + ) + status_code = response.status_code + + if 200 <= status_code < 300: + message = 'Success' + else: + message = 'Mail client failed with status code ' + str(status_code) + logging.error(message) + + return make_response(jsonify({'msg': message}), status_code) + except MailClientNotImplemented as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + return make_response(jsonify({'msg': message}), HTTPStatus.NOT_IMPLEMENTED) + except Exception as e1: + message = 'Encountered exception: ' + str(e1) + logging.exception(message) + return make_response(jsonify({'msg': message}), HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/frontend/amundsen_application/api/utils/request_utils.py b/frontend/amundsen_application/api/utils/request_utils.py new file mode 100644 index 0000000000..35999eb5a5 --- /dev/null +++ b/frontend/amundsen_application/api/utils/request_utils.py @@ -0,0 +1,133 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, Optional + +import requests +from flask import current_app as app + + +def get_query_param(args: Dict, param: str, error_msg: Optional[str] = None) -> str: + value = args.get(param) + if value is None: + msg = 'A {0} parameter must be provided'.format(param) if error_msg is None else error_msg + raise Exception(msg) + return value + + +def request_metadata(*, # type: ignore + url: str, + method: str = 'GET', + headers=None, + timeout_sec: int = 0, + data=None, + json=None): + """ + Helper function to make a request to metadata service. + Sets the client and header information based on the configuration + :param headers: Optional headers for the request, e.g. specifying Content-Type + :param method: DELETE | GET | POST | PUT + :param url: The request URL + :param timeout_sec: Number of seconds before timeout is triggered. + :param data: Optional request payload + :return: + """ + if headers is None: + headers = {} + + if app.config['REQUEST_HEADERS_METHOD']: + headers.update(app.config['REQUEST_HEADERS_METHOD'](app)) + elif app.config['METADATASERVICE_REQUEST_HEADERS']: + headers.update(app.config['METADATASERVICE_REQUEST_HEADERS']) + return request_wrapper(method=method, + url=url, + client=app.config['METADATASERVICE_REQUEST_CLIENT'], + headers=headers, + timeout_sec=timeout_sec, + data=data, + json=json) + + +def request_search(*, # type: ignore + url: str, + method: str = 'GET', + headers=None, + timeout_sec: int = 0, + data=None, + json=None): + """ + Helper function to make a request to search service. + Sets the client and header information based on the configuration + :param headers: Optional headers for the request, e.g. specifying Content-Type + :param method: DELETE | GET | POST | PUT + :param url: The request URL + :param timeout_sec: Number of seconds before timeout is triggered. 
+ :param data: Optional request payload + :return: + """ + if headers is None: + headers = {} + + if app.config['REQUEST_HEADERS_METHOD']: + headers.update(app.config['REQUEST_HEADERS_METHOD'](app)) + elif app.config['SEARCHSERVICE_REQUEST_HEADERS']: + headers.update(app.config['SEARCHSERVICE_REQUEST_HEADERS']) + + return request_wrapper(method=method, + url=url, + client=app.config['SEARCHSERVICE_REQUEST_CLIENT'], + headers=headers, + timeout_sec=timeout_sec, + data=data, + json=json) + + +# TODO: Define an interface for envoy_client +def request_wrapper(method: str, url: str, client, headers, timeout_sec: int, data=None, json=None): # type: ignore + """ + Wraps a request to use Envoy client and headers, if available + :param method: DELETE | GET | POST | PUT + :param url: The request URL + :param client: Optional Envoy client + :param headers: Optional Envoy request headers + :param timeout_sec: Number of seconds before timeout is triggered. Not used with Envoy + :param data: Optional request payload + :return: + """ + # If no timeout specified, use the one from the configurations. + timeout_sec = timeout_sec or app.config['REQUEST_SESSION_TIMEOUT_SEC'] + + if client is not None: + if method == 'DELETE': + return client.delete(url, headers=headers, raw_response=True, data=data, json=json) + elif method == 'GET': + return client.get(url, headers=headers, raw_response=True) + elif method == 'POST': + return client.post(url, headers=headers, raw_response=True, raw_request=True, data=data, json=json) + elif method == 'PUT': + return client.put(url, headers=headers, raw_response=True, raw_request=True, data=data, json=json) + else: + raise Exception('Method not allowed: {}'.format(method)) + else: + with build_session() as s: + if method == 'DELETE': + return s.delete(url, headers=headers, timeout=timeout_sec, data=data, json=json) + elif method == 'GET': + return s.get(url, headers=headers, timeout=timeout_sec) + elif method == 'POST': + return s.post(url, headers=headers, timeout=timeout_sec, data=data, json=json) + elif method == 'PUT': + return s.put(url, headers=headers, timeout=timeout_sec, data=data, json=json) + else: + raise Exception('Method not allowed: {}'.format(method)) + + +def build_session() -> requests.Session: + session = requests.Session() + + cert = app.config.get('MTLS_CLIENT_CERT') + key = app.config.get('MTLS_CLIENT_KEY') + if cert is not None and key is not None: + session.cert = (cert, key) + + return session diff --git a/frontend/amundsen_application/api/utils/response_utils.py b/frontend/amundsen_application/api/utils/response_utils.py new file mode 100644 index 0000000000..6b3a85f31b --- /dev/null +++ b/frontend/amundsen_application/api/utils/response_utils.py @@ -0,0 +1,19 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging + +from typing import Dict # noqa: F401 + +from flask import Response, jsonify, make_response + + +def create_error_response(*, message: str, payload: Dict, status_code: int) -> Response: + """ + Logs and info level log with the given message, and returns a response with: + 1. The given message as 'msg' in the response data + 2. 
The given status code as thge response status code + """ + logging.info(message) + payload['msg'] = message + return make_response(jsonify(payload), status_code) diff --git a/frontend/amundsen_application/api/utils/search_utils.py b/frontend/amundsen_application/api/utils/search_utils.py new file mode 100644 index 0000000000..f582a57b28 --- /dev/null +++ b/frontend/amundsen_application/api/utils/search_utils.py @@ -0,0 +1,148 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Dict, List # noqa: F401 + +from http import HTTPStatus + +from flask import current_app as app + +from amundsen_application.api.utils.request_utils import request_search + +from amundsen_common.models.search import Filter, SearchRequest + +from amundsen_application.models.user import dump_user, load_user + +LOGGER = logging.getLogger(__name__) + +# These can move to a configuration when we have custom use cases outside of these default values +valid_search_fields = { + 'table': { + 'badges', + 'column', + 'database', + 'schema', + 'table', + 'tag' + }, + 'dashboard': { + 'group_name', + 'name', + 'product', + 'tag' + }, + 'feature': { + 'badges', + 'entity', + 'feature_name', + 'feature_group', + 'tags' + } +} + + +def map_dashboard_result(result: Dict) -> Dict: + return { + 'type': 'dashboard', + 'key': result.get('key', None), + 'uri': result.get('uri', None), + 'url': result.get('url', None), + 'group_name': result.get('group_name', None), + 'name': result.get('name', None), + 'product': result.get('product', None), + 'tag': result.get('tag', None), + 'description': result.get('description', None), + 'last_successful_run_timestamp': result.get('last_successful_run_timestamp', None), + 'highlight': result.get('highlight', {}), + } + + +def map_table_result(result: Dict) -> Dict: + name = result.get('name') if result.get('name') else result.get('table') + return { + 'type': 'table', + 'key': result.get('key', None), + 'name': name, + 'cluster': result.get('cluster', None), + 'description': result.get('description', None), + 'database': result.get('database', None), + 'schema': result.get('schema', None), + 'schema_description': result.get('schema_description', None), + 'badges': result.get('badges', None), + 'last_updated_timestamp': result.get('last_updated_timestamp', None), + 'highlight': result.get('highlight', None), + } + + +def map_feature_result(result: Dict) -> Dict: + return { + 'type': 'feature', + 'description': result.get('description', None), + 'key': result.get('key', None), + 'last_updated_timestamp': result.get('last_updated_timestamp', None), + 'name': result.get('feature_name', None), + 'feature_group': result.get('feature_group', None), + 'version': result.get('version', None), + 'availability': result.get('availability', None), + 'entity': result.get('entity', None), + 'badges': result.get('badges', None), + 'status': result.get('status', None), + 'highlight': result.get('highlight', {}), + } + + +def map_user_result(result: Dict) -> Dict: + user_result = dump_user(load_user(result)) + user_result['type'] = 'user' + user_result['highlight'] = result.get('highlight', {}) + return user_result + + +def generate_query_json(*, filters: Dict = {}, page_index: int, search_term: str) -> Dict: + """ + Transforms the given paramaters to the query json for the search service according to + the api defined at: + https://github.com/lyft/amundsensearchlibrary/blob/master/search_service/api/swagger_doc/table/search_table_filter.yml 
+ https://github.com/lyft/amundsensearchlibrary/blob/master/search_service/api/swagger_doc/dashboard/search_dashboard_filter.yml + """ + + return { + 'page_index': int(page_index), + 'search_request': { + 'type': 'AND', + 'filters': filters + }, + 'query_term': search_term + } + + +def execute_search_document_request(request_json: str, method: str) -> int: + search_service_base = app.config['SEARCHSERVICE_BASE'] + search_document_url = f'{search_service_base}/v2/document' + update_response = request_search( + url=search_document_url, + method=method, + headers={'Content-Type': 'application/json'}, + data=request_json, + ) + status_code = update_response.status_code + if status_code != HTTPStatus.OK: + LOGGER.info(f'Failed to execute {method} for {request_json} in searchservice, status code: {status_code}') + LOGGER.info(update_response.text) + + return status_code + + +def generate_query_request(*, filters: List[Filter] = [], + resources: List[str] = [], + page_index: int = 0, + results_per_page: int = 10, + search_term: str, + highlight_options: Dict) -> SearchRequest: + return SearchRequest(query_term=search_term, + resource_types=resources, + page_index=page_index, + results_per_page=results_per_page, + filters=filters, + highlight_options=highlight_options) diff --git a/frontend/amundsen_application/api/v0.py b/frontend/amundsen_application/api/v0.py new file mode 100644 index 0000000000..b106e1fcf8 --- /dev/null +++ b/frontend/amundsen_application/api/v0.py @@ -0,0 +1,49 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging + +from http import HTTPStatus + +from flask import Response, jsonify, make_response +from flask import current_app as app +from flask.blueprints import Blueprint + +from amundsen_application.api.metadata.v0 import USER_ENDPOINT +from amundsen_application.api.utils.request_utils import request_metadata +from amundsen_application.models.user import load_user, dump_user + + +LOGGER = logging.getLogger(__name__) + +blueprint = Blueprint('main', __name__, url_prefix='/api') + + +@blueprint.route('/auth_user', methods=['GET']) +def current_user() -> Response: + try: + if app.config['AUTH_USER_METHOD']: + user = app.config['AUTH_USER_METHOD'](app) + else: + raise Exception('AUTH_USER_METHOD is not configured') + + url = '{0}{1}/{2}'.format(app.config['METADATASERVICE_BASE'], USER_ENDPOINT, user.user_id) + + response = request_metadata(url=url) + status_code = response.status_code + if status_code == HTTPStatus.OK: + message = 'Success' + else: + message = 'Encountered error: failed to fetch user with user_id: {0}'.format(user.user_id) + logging.error(message) + + payload = { + 'msg': message, + 'user': dump_user(load_user(response.json())) + } + return make_response(jsonify(payload), status_code) + except Exception as e: + message = 'Encountered exception: ' + str(e) + logging.exception(message) + payload = {'msg': message} + return make_response(jsonify(payload), HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/frontend/amundsen_application/base/__init__.py b/frontend/amundsen_application/base/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/frontend/amundsen_application/base/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. 
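To make the search plumbing concrete, a small sketch of how the helpers above combine into the body that is POSTed to the search service's /v2/search endpoint; the search term, tag and paging values are illustrative:

import json
from amundsen_common.models.search import Filter, SearchRequestSchema
from amundsen_application.api.utils.search_utils import generate_query_request

# Hypothetical: first page of ten 'table' results for 'rides', filtered to tag 'core'.
query_request = generate_query_request(
    filters=[Filter(name='tag', values=['core'], operation='OR')],
    resources=['table'],
    page_index=0,
    results_per_page=10,
    search_term='rides',
    highlight_options={},
)
request_json = json.dumps(SearchRequestSchema().dump(query_request))  # body sent to /v2/search

The /api/auth_user handler above relies on AUTH_USER_METHOD returning an object with a user_id; a minimal sketch, assuming an auth proxy injects the user's e-mail in a request header (the header name and user fields beyond user_id are assumptions):

from flask import request
from amundsen_application.models.user import load_user

def get_auth_user(app):  # wire up as AUTH_USER_METHOD = get_auth_user in a custom config
    email = request.headers.get('X-Auth-Request-Email', 'unknown@example.com')  # example header
    return load_user({'email': email, 'user_id': email})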
+# SPDX-License-Identifier: Apache-2.0 diff --git a/frontend/amundsen_application/base/base_announcement_client.py b/frontend/amundsen_application/base/base_announcement_client.py new file mode 100644 index 0000000000..a82e23668e --- /dev/null +++ b/frontend/amundsen_application/base/base_announcement_client.py @@ -0,0 +1,47 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import abc +import logging + +from http import HTTPStatus + +from flask import jsonify, make_response, Response +from marshmallow import ValidationError + +from amundsen_application.models.announcements import Announcements, AnnouncementsSchema + + +class BaseAnnouncementClient(abc.ABC): + @abc.abstractmethod + def __init__(self) -> None: + pass # pragma: no cover + + @abc.abstractmethod + def get_posts(self) -> Announcements: + """ + Returns an instance of amundsen_application.models.announcements.Announcements, which should match + amundsen_application.models.announcements.AnnouncementsSchema + """ + pass # pragma: no cover + + def _get_posts(self) -> Response: + def _create_error_response(message: str) -> Response: + logging.exception(message) + payload = jsonify({'posts': [], 'msg': message}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + try: + announcements = self.get_posts() + except Exception as e: + message = 'Encountered exception getting posts: ' + str(e) + return _create_error_response(message) + + try: + data = AnnouncementsSchema().dump(announcements) + AnnouncementsSchema().load(data) # validate returned object + payload = jsonify({'posts': data.get('posts'), 'msg': 'Success'}) + return make_response(payload, HTTPStatus.OK) + except ValidationError as err: + message = 'Announcement data dump returned errors: ' + str(err.messages) + return _create_error_response(message) diff --git a/frontend/amundsen_application/base/base_bigquery_preview_client.py b/frontend/amundsen_application/base/base_bigquery_preview_client.py new file mode 100644 index 0000000000..978f3f9729 --- /dev/null +++ b/frontend/amundsen_application/base/base_bigquery_preview_client.py @@ -0,0 +1,95 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from http import HTTPStatus +import abc +import logging +from typing import Any, Dict, List, Optional +from amundsen_application.base.base_preview_client import BasePreviewClient +from amundsen_application.models.preview_data import ( + ColumnItem, + PreviewData, + PreviewDataSchema, +) +from flask import Response, make_response, jsonify +from marshmallow import ValidationError +from google.cloud import bigquery + +import json +import decimal + + +class BaseBigqueryPreviewClient(BasePreviewClient): + """ + Returns a Response object, where the response data represents a json object + with the preview data accessible on 'preview_data' key. The preview data should + match amundsen_application.models.preview_data.PreviewDataSchema + """ + + def __init__(self, + bq_client: bigquery.Client, + preview_limit: int = 5, + previewable_projects: Optional[List] = None) -> None: + # Client passed from custom implementation. See example implementation. + self.bq_client = bq_client + self.preview_limit = preview_limit + # List of projects that are approved for whitelisting. None(Default) approves all google projects. 
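Subclasses of BaseAnnouncementClient only implement get_posts; the base class handles schema validation and the HTTP response. A minimal static sketch (titles, dates and content are placeholders), simpler than the SQLAlchemy-backed example that appears further below:

from amundsen_application.base.base_announcement_client import BaseAnnouncementClient
from amundsen_application.models.announcements import Announcements, Post

class StaticAnnouncementClient(BaseAnnouncementClient):  # hypothetical example class
    def __init__(self) -> None:
        pass

    def get_posts(self) -> Announcements:
        return Announcements([
            Post(title='Scheduled maintenance',               # placeholder announcement
                 date='May 01 2023 09:00:00',
                 html_content='<p>Amundsen will be read-only for one hour.</p>'),
        ])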
+ self.previewable_projects = previewable_projects + + @abc.abstractmethod + def _bq_list_rows( + self, gcp_project_id: str, table_project_name: str, table_name: str + ) -> PreviewData: + """ + Returns PreviewData from bigquery list rows api. + """ + pass # pragma: no cover + + def _column_item_from_bq_schema(self, schemafield: bigquery.SchemaField, key: Optional[str] = None) -> List: + """ + Recursively build ColumnItems from the bigquery schema + """ + all_fields = [] + if schemafield.field_type != "RECORD": + name = key + "." + schemafield.name if key else schemafield.name + return [ColumnItem(name, schemafield.field_type)] + for field in schemafield.fields: + if key: + name = key + "." + schemafield.name + else: + name = schemafield.name + all_fields.extend(self._column_item_from_bq_schema(field, name)) + return all_fields + + def get_preview_data(self, params: Dict, optionalHeaders: Optional[Dict] = None) -> Response: + if self.previewable_projects and params["cluster"] not in self.previewable_projects: + return make_response(jsonify({"preview_data": {}}), HTTPStatus.FORBIDDEN) + + preview_data = self._bq_list_rows( + params["cluster"], + params["schema"], + params["tableName"], + ) + try: + data = PreviewDataSchema().dump(preview_data) + PreviewDataSchema().load(data) # for validation only + payload = json.dumps({"preview_data": data}, cls=Encoder) + return make_response(payload, HTTPStatus.OK) + except ValidationError as err: + logging.error("PreviewDataSchema serialization error + " + str(err.messages)) + return make_response( + jsonify({"preview_data": {}}), HTTPStatus.INTERNAL_SERVER_ERROR + ) + + def get_feature_preview_data(self, params: Dict, optionalHeaders: Optional[Dict] = None) -> Response: + pass + + +class Encoder(json.JSONEncoder): + """ + Customized json encoder class to address the parsing of decimal/numeric data types into float. + """ + + def default(self, obj: Any) -> Any: + if isinstance(obj, decimal.Decimal): + return float(obj) diff --git a/frontend/amundsen_application/base/base_issue_tracker_client.py b/frontend/amundsen_application/base/base_issue_tracker_client.py new file mode 100644 index 0000000000..27555551d5 --- /dev/null +++ b/frontend/amundsen_application/base/base_issue_tracker_client.py @@ -0,0 +1,44 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import abc +from typing import Any + +from amundsen_application.models.data_issue import DataIssue +from amundsen_application.models.issue_results import IssueResults + + +class BaseIssueTrackerClient(abc.ABC): + @abc.abstractmethod + def __init__(self) -> None: + pass # pragma: no cover + + @abc.abstractmethod + def get_issues(self, table_uri: str) -> IssueResults: + """ + Gets issues from the issue tracker + :param table_uri: Table Uri ie databasetype://database/table + :return: + """ + raise NotImplementedError # pragma: no cover + + @abc.abstractmethod + def create_issue(self, + table_uri: str, + title: str, + description: str, + priority_level: str, + table_url: str, + **kwargs: Any) -> DataIssue: + """ + Given a title, description, and table key, creates a ticket in the configured project + Automatically places the table_uri in the description of the ticket. + Returns the ticket information, including URL. 
+ :param description: User provided description for the jira ticket + :param priority_level: Priority level for the ticket + :param table_uri: Table URI ie databasetype://database/table + :param title: Title of the ticket + :param table_url: Link to access the table + :return: A single ticket + """ + raise NotImplementedError # pragma: no cover diff --git a/frontend/amundsen_application/base/base_mail_client.py b/frontend/amundsen_application/base/base_mail_client.py new file mode 100644 index 0000000000..7cd4eedffe --- /dev/null +++ b/frontend/amundsen_application/base/base_mail_client.py @@ -0,0 +1,34 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import abc +from typing import Dict, List, Optional + +from flask import Response + + +class BaseMailClient(abc.ABC): + @abc.abstractmethod + def __init__(self, recipients: List[str]) -> None: + pass # pragma: no cover + + @abc.abstractmethod + def send_email(self, + html: str, + subject: str, + optional_data: Optional[Dict] = None, + recipients: Optional[List[str]] = None, + sender: Optional[str] = None) -> Response: + """ + Sends an email using the following parameters + :param html: HTML email content + :param subject: The subject of the email + :param optional_data: An optional dictionary of any values needed for custom implementations + :param recipients: An optional list of recipients for the email, the implementation + for this class should determine whether to use the recipients from the function, + the __init__ or both + :param sender: An optional sending address associated with the email, the implementation + should determine whether to use this value or another (e.g. from envvars) + :return: + """ + raise NotImplementedError # pragma: no cover diff --git a/frontend/amundsen_application/base/base_notice_client.py b/frontend/amundsen_application/base/base_notice_client.py new file mode 100644 index 0000000000..3fbb64a2d1 --- /dev/null +++ b/frontend/amundsen_application/base/base_notice_client.py @@ -0,0 +1,21 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC +from abc import abstractmethod +from flask import Response + + +class BaseNoticeClient(ABC): + """ + Abstract interface for a client that provides alerts affecting a given table. + """ + + @abstractmethod + def get_table_notices_summary(self, *, table_key: str) -> Response: + """ + Returns table alerts response for a given table URI + :param table_key: Table key for table to get alerts for + :return: flask Response object + """ + raise NotImplementedError diff --git a/frontend/amundsen_application/base/base_preview.py b/frontend/amundsen_application/base/base_preview.py new file mode 100644 index 0000000000..4b1f10f3c2 --- /dev/null +++ b/frontend/amundsen_application/base/base_preview.py @@ -0,0 +1,20 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABCMeta, abstractmethod + + +class BasePreview(metaclass=ABCMeta): + """ + A Preview interface for other product to implement. For example, see ModePreview. 
+ """ + + @abstractmethod + def get_preview_image(self, *, uri: str) -> bytes: + """ + Returns image bytes given URI + :param uri: + :return: + :raises: FileNotFound when either Report is not available or Preview image is not available + """ + pass diff --git a/frontend/amundsen_application/base/base_preview_client.py b/frontend/amundsen_application/base/base_preview_client.py new file mode 100644 index 0000000000..3f8a0b4137 --- /dev/null +++ b/frontend/amundsen_application/base/base_preview_client.py @@ -0,0 +1,31 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import abc +from typing import Dict, Optional + +from flask import Response + + +class BasePreviewClient(abc.ABC): + @abc.abstractmethod + def __init__(self) -> None: + pass # pragma: no cover + + @abc.abstractmethod + def get_preview_data(self, params: Dict, optionalHeaders: Optional[Dict] = None) -> Response: + """ + Returns a Response object, where the response data represents a json object + with the preview data accessible on 'preview_data' key. The preview data should + match amundsen_application.models.preview_data.PreviewDataSchema + """ + raise NotImplementedError # pragma: no cover + + @abc.abstractmethod + def get_feature_preview_data(self, params: Dict, optionalHeaders: Optional[Dict] = None) -> Response: + """ + Returns a Response object, where the response data represents a json object + with the preview data accessible on 'preview_data' key. The preview data should + match amundsen_application.models.preview_data.PreviewDataSchema + """ + raise NotImplementedError # pragma: no cover diff --git a/frontend/amundsen_application/base/base_quality_client.py b/frontend/amundsen_application/base/base_quality_client.py new file mode 100644 index 0000000000..377c3cd2d2 --- /dev/null +++ b/frontend/amundsen_application/base/base_quality_client.py @@ -0,0 +1,29 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABCMeta, abstractmethod +from flask import Response + + +class BaseQualityClient(metaclass=ABCMeta): + """ + An abstract interface for a Quality Service Client + """ + + @abstractmethod + def get_table_quality_checks_summary(self, *, table_key: str) -> Response: + """ + Returns table quality checks for a given table uri + :param table_key: Table key for the table whose table quality + :return: TableQualityChecks object + """ + raise NotImplementedError # pragma: no cover + + @abstractmethod + def get_table_quality_checks(self, *, table_key: str) -> bytes: + """ + Returns table quality checks for a given table uri + :param table_key: Table key for the table whose table quality + :return: TableQualityChecks object + """ + raise NotImplementedError # pragma: no cover diff --git a/frontend/amundsen_application/base/base_redash_preview_client.py b/frontend/amundsen_application/base/base_redash_preview_client.py new file mode 100644 index 0000000000..5da8a893a3 --- /dev/null +++ b/frontend/amundsen_application/base/base_redash_preview_client.py @@ -0,0 +1,274 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +import abc +from enum import Enum +import logging +import requests as r +import time + +from flask import Response as FlaskResponse, make_response, jsonify +from http import HTTPStatus +from typing import Any, Dict, Optional, Tuple + +from amundsen_application.base.base_preview_client import BasePreviewClient +from amundsen_application.models.preview_data import ColumnItem, PreviewData, PreviewDataSchema + + +LOGGER = logging.getLogger(__name__) + + +REDASH_SUBMIT_QUERY_ENDPOINT = '{redash_host}/api/queries/{query_id}/results' +REDASH_TRACK_JOB_ENDPOINT = '{redash_host}/api/jobs/{job_id}' +REDASH_QUERY_RESULTS_ENDPOINT = '{redash_host}/api/query_results/{query_result_id}' + + +class RedashApiKeyNotProvidedException(Exception): + pass + + +class RedashQueryCouldNotCompleteException(Exception): + pass + + +class RedashQueryTemplateDoesNotExistForResource(Exception): + pass + + +class RedashApiResponse(Enum): + PENDING = 1 # (waiting to be executed) + STARTED = 2 # (executing) + SUCCESS = 3 + FAILURE = 4 + CANCELLED = 5 + + +class BaseRedashPreviewClient(BasePreviewClient): + """ + Generic client for using Redash as a preview client backend. + + Redash does not allow arbitrary queries to be submitted but it does allow + the creation of templated queries that can be saved and referenced. Amundsen + uses these templated queries to pass in arguments such as the schema name + and table name in order to dynamically build a query on the fly. + + The suggested format of the query template is: + + select {{ SELECT_FIELDS }} + from {{ SCHEMA_NAME }}.{{ TABLE_NAME }} + {{ WHERE_CLAUSE }} + limit {{ RCD_LIMIT }} + + You will need to use the params (e.g. database, cluster, schema and table names) + to idenfiy the specific query ID in Redash to use. This is done via the + `get_redash_query_id` method. + + The template values in the Redash query will be filled by the `build_redash_query_params` + function. + """ + + def __init__(self, redash_host: str, user_api_key: Optional[str] = None) -> None: + self.redash_host = redash_host + self.user_api_key: Optional[str] = user_api_key + self.headers: Optional[Dict] = None + self.default_query_limit = 50 + self.max_redash_cache_age = 86400 # One day + + @abc.abstractmethod + def get_redash_query_id(self, params: Dict) -> Optional[int]: + """ + Retrieves the query template that should be executed for the given + source / database / schema / table combination. + + Redash Connections are generally unique to the source and database. + For example, Snowflake account that has two databases would require two + separate connections in Redash. This would require at least one query + template per connection. + + The query ID can be found in the URL of the query when using the Redash GUI. + + :param params: A dictionary of input parameters containing the database, + cluster, schema and tableName + :returns: the ID for the query in Redash. Can be None if one does not exist. + """ + pass # pragma: no cover + + def _build_headers(self, params: Dict) -> None: + """ + Generates the headers to use for the API invocation. Attemps to use a + Query API key, if it exists, then falls back to a User API if no + query API key is returned. 
+ + Background on Redash API keys: https://redash.io/help/user-guide/integrations-and-api/api + """ + api_key = self._get_query_api_key(params) or self.user_api_key + if api_key is None: + raise RedashApiKeyNotProvidedException('No API key provided') + self.headers = {"Authorization": "Key {}".format(api_key)} + + def _get_query_api_key(self, params: Dict) -> Optional[str]: + """ + This function can be overridden by sub classes to look up the specific + API key to use for a given database / cluster / schema / table combination. + """ + return None + + def get_select_fields(self, params: Dict) -> str: + """ + Allows customization of the fields in the select clause. This can be used to + return a subset of fields or to apply functions (e.g. to mask data) on a + table by table basis. Defaults to `*` for all fields. + + This string should be valid SQL AND fit BETWEEN the brackets `SELECT {} FROM ...` + + :param params: A dictionary of input parameters containing the database, + cluster, schema and tableName + :returns: a string corresponding to fields to select in the query + """ + return '*' + + def get_where_clause(self, params: Dict) -> str: + """ + Allows customization of the 'WHERE' clause to be provided for each set of parameters + by the client implementation. Defaults to an empty string. + """ + return '' + + def build_redash_query_params(self, params: Dict) -> Dict: + """ + Builds a dictionary of parameters that will be injected into the Redash query + template. The keys in this dictionary MUST be a case-sensitive match to the + template names in the Redash query and you MUST have the exact same parameters, + no more, no less. + + Override this function to provide custom values. + """ + return { + 'parameters': { + 'SELECT_FIELDS': self.get_select_fields(params), + 'SCHEMA_NAME': params.get('schema'), + 'TABLE_NAME': params.get('tableName'), + 'WHERE_CLAUSE': self.get_where_clause(params), + 'RCD_LIMIT': str(self.default_query_limit) + }, + 'max_age': self.max_redash_cache_age + } + + def _start_redash_query(self, query_id: int, query_params: Dict) -> Tuple[Any, bool]: + """ + Starts a query in Redash. Returns a job ID that can be used to poll for + the job status. + + :param query_id: The ID of the query in the Redash system. This can + be retrieved by viewing the URL for your query template in the + Redash GUI. + :param query_params: A dictionary of parameters to inject into the + corresponding query's template + :return: A tuple of the response object and boolean. The response object + changes based off of whether or not the result from Redash came from + the cache. + The boolean is True if the result came from the Redash cache, otherwise False. + """ + url_inputs = {'redash_host': self.redash_host, 'query_id': query_id} + query_url = REDASH_SUBMIT_QUERY_ENDPOINT.format(**url_inputs) + + resp = r.post(query_url, json=query_params, headers=self.headers) + resp_json = resp.json() + + LOGGER.debug('Response from redash query: %s', resp_json) + + # When submitting a query, Redash can return 2 distinct payloads. One if the + # query result has been cached by Redash and one if the query was submitted + # to be executed. The 'job' object is returned if the query is not cached. + if 'job' in resp_json: + redash_cached = False + else: + redash_cached = True + + return resp_json, redash_cached + + def _wait_for_query_finish(self, job_id: str, max_wait: int = 60) -> str: + """ + Waits for the query to finish and validates that a successful response is returned. 
+ + :param job_id: the ID for the job executing the query + :return: a query result ID tha can be used to fetch the results + """ + url_inputs = {'redash_host': self.redash_host, 'job_id': job_id} + query_url = REDASH_TRACK_JOB_ENDPOINT.format(**url_inputs) + + query_result_id: Optional[str] = None + max_time = time.time() + max_wait + + while time.time() < max_time: + resp = r.get(query_url, headers=self.headers) + resp_json = resp.json() + + LOGGER.debug('Received response from Redash job %s: %s', job_id, resp_json) + + job_info = resp_json['job'] + job_status = RedashApiResponse(job_info['status']) + + if job_status == RedashApiResponse.SUCCESS: + query_result_id = job_info['query_result_id'] + break + + elif job_status == RedashApiResponse.FAILURE: + raise RedashQueryCouldNotCompleteException(job_info['error']) + time.sleep(.5) + + if query_result_id is None: + raise RedashQueryCouldNotCompleteException('Query execution took too long') + + return query_result_id + + def _get_query_results(self, query_result_id: str) -> Dict: + """ + Retrieves query results from a successful query run + + :param query_result_id: ID returned by Redash after a successful query execution + :return: A Redash response dictionary + """ + url_inputs = {'redash_host': self.redash_host, 'query_result_id': query_result_id} + results_url = REDASH_QUERY_RESULTS_ENDPOINT.format(**url_inputs) + resp = r.get(results_url, headers=self.headers) + return resp.json() + + def get_preview_data(self, params: Dict, optionalHeaders: Optional[Dict] = None) -> FlaskResponse: + """ + Returns a FlaskResponse object, where the response data represents a json object + with the preview data accessible on 'preview_data' key. The preview data should + match amundsen_application.models.preview_data.PreviewDataSchema + """ + LOGGER.debug('Retrieving preview data from Redash with params: %s', params) + try: + query_id = self.get_redash_query_id(params) + if query_id is None: + raise RedashQueryTemplateDoesNotExistForResource('Could not find query for params: %s', params) + + # Build headers to use the Query API key or User API key + self._build_headers(params) + + query_params = self.build_redash_query_params(params) + query_results, cached_result = self._start_redash_query(query_id=query_id, query_params=query_params) + + # Redash attempts to use internal caching. 
The format of the response + # changes based on whether or not a cached response is returned + if not cached_result: + query_result_id = self._wait_for_query_finish(job_id=query_results['job']['id']) + query_results = self._get_query_results(query_result_id=query_result_id) + + columns = [ColumnItem(c['name'], c['type']) for c in query_results['query_result']['data']['columns']] + preview_data = PreviewData(columns, query_results['query_result']['data']['rows']) + + data = PreviewDataSchema().dump(preview_data) + PreviewDataSchema().load(data) # for validation only + payload = jsonify({'preview_data': data}) + return make_response(payload, HTTPStatus.OK) + + except Exception as e: + LOGGER.error('ERROR getting Redash preview: %s', e) + return make_response(jsonify({'preview_data': {}}), HTTPStatus.INTERNAL_SERVER_ERROR) + + def get_feature_preview_data(self, params: Dict, optionalHeaders: Optional[Dict] = None) -> FlaskResponse: + pass diff --git a/frontend/amundsen_application/base/base_s3_preview_client.py b/frontend/amundsen_application/base/base_s3_preview_client.py new file mode 100644 index 0000000000..b0c14a9e8e --- /dev/null +++ b/frontend/amundsen_application/base/base_s3_preview_client.py @@ -0,0 +1,48 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import abc +import logging +from http import HTTPStatus +from typing import Dict, Optional + +from amundsen_application.base.base_preview_client import BasePreviewClient +from amundsen_application.models.preview_data import (PreviewData, + PreviewDataSchema) +from flask import Response as FlaskResponse +from flask import jsonify, make_response +from marshmallow import ValidationError + + +class BaseS3PreviewClient(BasePreviewClient): + def __init__(self) -> None: + pass + + @abc.abstractmethod + def get_s3_preview_data(self, *, params: Dict) -> PreviewData: + """ + Returns the data from S3 in PreviewData model format + """ + pass # pragma: no cover + + def get_preview_data(self, params: Dict, optionalHeaders: Optional[Dict] = None) -> FlaskResponse: + try: + preview_data = self.get_s3_preview_data(params=params) + try: + data = PreviewDataSchema().dump(preview_data) + PreviewDataSchema().load(data) # for validation only + payload = jsonify({'preview_data': data}) + return make_response(payload, HTTPStatus.OK) + except ValidationError as err: + logging.error("PreviewDataSchema serialization error " + str(err.messages)) + return make_response(jsonify({'preview_data': {}}), HTTPStatus.INTERNAL_SERVER_ERROR) + except Exception as err: + logging.error("error getting s3 preview data " + str(err)) + return make_response(jsonify({'preview_data': {}}), HTTPStatus.INTERNAL_SERVER_ERROR) + + def get_feature_preview_data(self, params: Dict, optionalHeaders: Optional[Dict] = None) -> FlaskResponse: + """ + BaseS3PreviewClient only supports data preview currently but this function needs to be stubbed to + implement the BasePreviewClient interface + """ + pass diff --git a/frontend/amundsen_application/base/base_superset_preview_client.py b/frontend/amundsen_application/base/base_superset_preview_client.py new file mode 100644 index 0000000000..c4336bbd52 --- /dev/null +++ b/frontend/amundsen_application/base/base_superset_preview_client.py @@ -0,0 +1,62 @@ +# Copyright Contributors to the Amundsen project. 
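Putting the Redash pieces together, a sketch of a concrete subclass; the host, API key and query-ID mapping are placeholders you would replace with values from your own Redash deployment:

from typing import Dict, Optional
from amundsen_application.base.base_redash_preview_client import BaseRedashPreviewClient

# Placeholder mapping of (database, cluster) to the saved Redash query template ID.
QUERY_IDS = {('snowflake', 'analytics'): 42}

class MyRedashPreviewClient(BaseRedashPreviewClient):  # hypothetical example class
    def __init__(self) -> None:
        super().__init__(redash_host='https://redash.example.com',   # placeholder host
                         user_api_key='my-user-api-key')             # placeholder key

    def get_redash_query_id(self, params: Dict) -> Optional[int]:
        return QUERY_IDS.get((params.get('database'), params.get('cluster')))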
+# SPDX-License-Identifier: Apache-2.0 + +import abc +import logging + +from flask import Response as FlaskResponse, make_response, jsonify +from http import HTTPStatus +from marshmallow import ValidationError +from requests import Response +from typing import Dict, Optional + +from amundsen_application.base.base_preview_client import BasePreviewClient +from amundsen_application.models.preview_data import ColumnItem, PreviewData, PreviewDataSchema + + +class BaseSupersetPreviewClient(BasePreviewClient): + @abc.abstractmethod + def __init__(self) -> None: + self.headers = {} # type: Dict + + @abc.abstractmethod + def post_to_sql_json(self, *, params: Dict, headers: Dict) -> Response: + """ + Returns the post response from Superset's `sql_json` endpoint + """ + pass # pragma: no cover + + def get_preview_data(self, params: Dict, optionalHeaders: Optional[Dict] = None) -> FlaskResponse: + """ + Returns a FlaskResponse object, where the response data represents a json object + with the preview data accessible on 'preview_data' key. The preview data should + match amundsen_application.models.preview_data.PreviewDataSchema + """ + try: + # Clone headers so that it does not mutate instance's state + headers = dict(self.headers) + + # Merge optionalHeaders into headers + if optionalHeaders is not None: + headers.update(optionalHeaders) + + # Request preview data + response = self.post_to_sql_json(params=params, headers=headers) + + # Verify and return the results + response_dict = response.json() + columns = [ColumnItem(c['name'], c['type']) for c in response_dict['columns']] + preview_data = PreviewData(columns, response_dict['data']) + try: + data = PreviewDataSchema().dump(preview_data) + PreviewDataSchema().load(data) # for validation only + payload = jsonify({'preview_data': data}) + return make_response(payload, response.status_code) + except ValidationError as err: + logging.error("PreviewDataSchema serialization error " + str(err.messages)) + return make_response(jsonify({'preview_data': {}}), HTTPStatus.INTERNAL_SERVER_ERROR) + except Exception: + return make_response(jsonify({'preview_data': {}}), HTTPStatus.INTERNAL_SERVER_ERROR) + + def get_feature_preview_data(self, params: Dict, optionalHeaders: Optional[Dict] = None) -> FlaskResponse: + pass diff --git a/frontend/amundsen_application/base/examples/__init__.py b/frontend/amundsen_application/base/examples/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/frontend/amundsen_application/base/examples/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/frontend/amundsen_application/base/examples/example_announcement_client.py b/frontend/amundsen_application/base/examples/example_announcement_client.py new file mode 100644 index 0000000000..32aa76b147 --- /dev/null +++ b/frontend/amundsen_application/base/examples/example_announcement_client.py @@ -0,0 +1,79 @@ +# Copyright Contributors to the Amundsen project. 
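+# Wiring sketch (the values below are illustrative): this example client is typically enabled
+# through the frontend configuration described in config.py, e.g.
+#
+#   ANNOUNCEMENT_CLIENT_ENABLED=true
+#   ANNOUNCEMENT_CLIENT=amundsen_application.base.examples.example_announcement_client.SQLAlchemyAnnouncementClient
+#
+# It also assumes SQLAlchemy is installed, which is why the import below is wrapped in a
+# try/except.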
+# SPDX-License-Identifier: Apache-2.0 + +from random import randint +from datetime import datetime, timedelta + +from amundsen_application.models.announcements import Announcements, Post +from amundsen_application.base.base_announcement_client import BaseAnnouncementClient + +try: + from sqlalchemy import Column, Integer, String, DateTime, create_engine + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.inspection import inspect + from sqlalchemy.orm import sessionmaker +except ModuleNotFoundError: + pass + +Base = declarative_base() + + +class DBAnnouncement(Base): # type: ignore + __tablename__ = 'announcements' + + id = Column(Integer, primary_key=True) + + date = Column(DateTime) + title = Column(String) + content = Column(String) + + +class SQLAlchemyAnnouncementClient(BaseAnnouncementClient): + def __init__(self) -> None: + self._setup_mysql() + + def _setup_mysql(self) -> None: + self.engine = create_engine('sqlite:////tmp/amundsen.db', echo=True) + + session = sessionmaker(bind=self.engine)() + + # add dummy announcements to preview + if not inspect(self.engine).has_table(DBAnnouncement.__tablename__): + Base.metadata.create_all(self.engine) + + announcements = [] + + dummy_announcement = """ + Lorem ipsum dolor sit amet, consectetur adipiscing elit. Donec at dapibus lorem. + Orci varius natoque penatibus et magnis dis parturient montes, nascetur ridiculus mus. + Suspendisse est lectus, bibendum vitae vestibulum vitae, commodo eu tortor. + Sed rhoncus augue eget turpis interdum, eu aliquam lectus ornare. Aenean tempus in mauris vitae viverra. + """ + + for i in range(randint(5, 9)): + announcement = DBAnnouncement(id=i + 1, + date=datetime.now() + timedelta(days=i + 1), + title=f'Test announcement title {i + 1}', + content=dummy_announcement) + + announcements.append(announcement) + + session.add_all(announcements) + session.commit() + + def get_posts(self) -> Announcements: + """ + Returns an instance of amundsen_application.models.announcements.Announcements, which should match + amundsen_application.models.announcements.AnnouncementsSchema + """ + session = sessionmaker(bind=self.engine)() + + posts = [] + + for row in session.query(DBAnnouncement).order_by(DBAnnouncement.date.desc()): + post = Post(title=row.title, + date=row.date.strftime('%b %d %Y %H:%M:%S'), + html_content=row.content) + posts.append(post) + + return Announcements(posts) diff --git a/frontend/amundsen_application/base/examples/example_bigquery_preview_client.py b/frontend/amundsen_application/base/examples/example_bigquery_preview_client.py new file mode 100644 index 0000000000..607fc2a173 --- /dev/null +++ b/frontend/amundsen_application/base/examples/example_bigquery_preview_client.py @@ -0,0 +1,49 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from amundsen_application.base.base_bigquery_preview_client import BaseBigqueryPreviewClient +from amundsen_application.models.preview_data import ( + PreviewData, +) +from google.cloud import bigquery +from flatten_dict import flatten + + +class BigqueryPreviewClient(BaseBigqueryPreviewClient): + """ + Returns a Response object, where the response data represents a json object + with the preview data accessible on 'preview_data' key. The preview data should + match amundsen_application.models.preview_data.PreviewDataSchema + """ + + def __init__(self) -> None: + # Requires access to a service account eg. 
+ # GOOGLE_APPLICATION_CREDENTIALS=path/serviceaccount.json or a mounted service kubernetes service account. + super().__init__(bq_client=bigquery.Client("your project here")) + + def _bq_list_rows( + self, gcp_project_id: str, table_project_name: str, table_name: str + ) -> PreviewData: + """ + Returns PreviewData from bigquery list rows api. + """ + table_id = f"{gcp_project_id}.{table_project_name}.{table_name}" + rows = self.bq_client.list_rows(table_id, max_results=self.preview_limit) + + # Make flat key ColumnItems from table schema. + columns = [] + for field in rows.schema: + extend_with = self._column_item_from_bq_schema(field) + columns.extend(extend_with) + + # Flatten rows and set missing empty keys to None, to avoid errors with undefined values + # in frontend + column_data = [] + for row in rows: + flat_row = flatten(dict(row), reducer="dot") + for key in columns: + if key.column_name not in flat_row: + flat_row[key.column_name] = None + column_data.append(flat_row) + + return PreviewData(columns, column_data) diff --git a/frontend/amundsen_application/base/examples/example_dremio_preview_client.py b/frontend/amundsen_application/base/examples/example_dremio_preview_client.py new file mode 100644 index 0000000000..b16ec1f4b4 --- /dev/null +++ b/frontend/amundsen_application/base/examples/example_dremio_preview_client.py @@ -0,0 +1,96 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from http import HTTPStatus +import logging +from typing import Dict, Optional # noqa: F401 + +from flask import Response, jsonify, make_response, current_app as app +from marshmallow import ValidationError +from pyarrow import flight + +from amundsen_application.base.base_superset_preview_client import BasePreviewClient +from amundsen_application.models.preview_data import PreviewData, PreviewDataSchema, ColumnItem + + +class _DremioAuthHandler(flight.ClientAuthHandler): + """ClientAuthHandler for connections to Dremio server endpoint. + """ + def __init__(self, username: str, password: str) -> None: + self.username = username + self.password = password + super(flight.ClientAuthHandler, self).__init__() + + def authenticate(self, outgoing: flight.ClientAuthSender, + incoming: flight.ClientAuthReader) -> None: + """Authenticate with Dremio user credentials. + """ + basic_auth = flight.BasicAuth(self.username, self.password) + outgoing.write(basic_auth.serialize()) + self.token = incoming.read() + + def get_token(self,) -> str: + """Get the token from this AuthHandler. 
+ """ + return self.token + + +class DremioPreviewClient(BasePreviewClient): + + SQL_STATEMENT = 'SELECT * FROM {schema}."{table}" LIMIT 50' + + def __init__(self,) -> None: + self.url = app.config['PREVIEW_CLIENT_URL'] + self.username = app.config['PREVIEW_CLIENT_USERNAME'] + self.password = app.config['PREVIEW_CLIENT_PASSWORD'] + + self.connection_args: Dict[str, bytes] = {} + tls_root_certs_path = app.config['PREVIEW_CLIENT_CERTIFICATE'] + if tls_root_certs_path is not None: + with open(tls_root_certs_path, "rb") as f: + self.connection_args["tls_root_certs"] = f.read() + + def get_preview_data(self, params: Dict, optionalHeaders: Optional[Dict] = None) -> Response: + """Preview data from Dremio source + """ + database = params.get('database') + if database != 'DREMIO': + logging.info('Skipping table preview for non-Dremio table') + return make_response(jsonify({'preview_data': {}}), HTTPStatus.OK) + + try: + # Format base SQL_STATEMENT with request table and schema + schema = '"{}"'.format(params['schema'].replace('.', '"."')) + table = params['tableName'] + sql = DremioPreviewClient.SQL_STATEMENT.format(schema=schema, + table=table) + + client = flight.FlightClient(self.url, **self.connection_args) + client.authenticate(_DremioAuthHandler(self.username, self.password)) + flight_descriptor = flight.FlightDescriptor.for_command(sql) + flight_info = client.get_flight_info(flight_descriptor) + reader = client.do_get(flight_info.endpoints[0].ticket) + + result = reader.read_all() + names = result.schema.names + types = result.schema.types + + columns = map(lambda x: x.to_pylist(), result.columns) + rows = [dict(zip(names, row)) for row in zip(*columns)] + column_items = [ColumnItem(n, t) for n, t in zip(names, types)] + + preview_data = PreviewData(column_items, rows) + try: + data = PreviewDataSchema().dump(preview_data) + PreviewDataSchema().load(data) # for validation only + payload = jsonify({'preview_data': data}) + return make_response(payload, HTTPStatus.OK) + except ValidationError as err: + logging.error(f'Error(s) occurred while building preview data: {err.messages}') + payload = jsonify({'preview_data': {}}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) + + except Exception as e: + logging.error(f'Encountered exception: {e}') + payload = jsonify({'preview_data': {}}) + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/frontend/amundsen_application/base/examples/example_mail_client.py b/frontend/amundsen_application/base/examples/example_mail_client.py new file mode 100644 index 0000000000..191a9051d1 --- /dev/null +++ b/frontend/amundsen_application/base/examples/example_mail_client.py @@ -0,0 +1,63 @@ +# Copyright Contributors to the Amundsen project. 
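+# Usage sketch (the recipient address is hypothetical): this client sends through Gmail SMTP
+# and reads its credentials from the AMUNDSEN_EMAIL / AMUNDSEN_EMAIL_PASSWORD environment
+# variables referenced below.
+#
+#   client = MailClient(recipients=['data-platform@example.com'])
+#   client.send_email(subject='Amundsen notification', html='<p>Hello!</p>')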
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +import smtplib + +from email.mime.multipart import MIMEMultipart +from email.mime.text import MIMEText +from http import HTTPStatus +from typing import Dict, List, Optional + +from flask import Response, jsonify, make_response + +from amundsen_application.base.base_mail_client import BaseMailClient + + +# based on https://stackoverflow.com/a/6270987 +class MailClient(BaseMailClient): + def __init__(self, recipients: List[str]) -> None: + self.recipients = recipients + + def send_email(self, + html: str, + subject: str, + optional_data: Optional[Dict] = None, + recipients: Optional[List[str]] = None, + sender: Optional[str] = None) -> Response: + if not sender: + sender = os.environ.get('AMUNDSEN_EMAIL') or '' # set me + if not recipients: + recipients = self.recipients + + sender_pass = os.environ.get('AMUNDSEN_EMAIL_PASSWORD') or '' # set me + + # Create message container - the correct MIME type + # to combine text and html is multipart/alternative. + msg = MIMEMultipart('alternative') + msg['Subject'] = subject + msg['From'] = sender + msg['To'] = ', '.join(recipients) + + # Record the MIME type of text/html + # and attach parts to message container. + msg.attach(MIMEText(html, 'html')) + + s = smtplib.SMTP('smtp.gmail.com') + try: + s.connect('smtp.gmail.com', 587) + s.ehlo() + s.starttls() + s.ehlo() + s.login(sender, sender_pass) + message = s.send_message(msg) + payload = jsonify({'msg': message}) + s.quit() + return make_response(payload, HTTPStatus.OK) + except Exception as e: + err_message = 'Encountered exception: ' + str(e) + logging.exception(err_message) + payload = jsonify({'msg': err_message}) + s.quit() + return make_response(payload, HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/frontend/amundsen_application/base/examples/example_redash_preview_client.py b/frontend/amundsen_application/base/examples/example_redash_preview_client.py new file mode 100644 index 0000000000..3a1c58baa2 --- /dev/null +++ b/frontend/amundsen_application/base/examples/example_redash_preview_client.py @@ -0,0 +1,118 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os + +from typing import Dict, Optional + +from amundsen_application.base.base_redash_preview_client import BaseRedashPreviewClient + + +LOGGER = logging.getLogger(__name__) + + +# Redash natively runs on port 5000, the same port as Amundsen. +# Make sure to update the running port to match your deployment! +DEFAULT_URL = 'http://localhost:5010' + + +# Update this mapping with your database.cluster and Redash query ID +SOURCE_DB_QUERY_MAP = { + 'snowflake.ca_covid': 1 +} + +# This example uses a common, system user, for the API key +REDASH_USER_API_KEY = os.environ.get('REDASH_USER_API_KEY', '') + + +def _build_db_cluster_key(params: Dict) -> str: + _db = params.get('database') + _cluster = params.get('cluster') + + db_cluster_key = f'{_db}.{_cluster}' + return db_cluster_key + + +class RedashSimplePreviewClient(BaseRedashPreviewClient): + def __init__(self, + *, + redash_host: str = DEFAULT_URL, + user_api_key: Optional[str] = REDASH_USER_API_KEY) -> None: + super().__init__(redash_host=redash_host, user_api_key=user_api_key) + + def get_redash_query_id(self, params: Dict) -> Optional[int]: + """ + Retrieves the query template that should be executed for the given + source / database / schema / table combination. + + Redash Connections are generally unique to the source and database. 
+ For example, Snowflake account that has two databases would require two + separate connections in Redash. This would require at least one query + template per connection. + + The query ID can be found in the URL of the query when using the Redash GUI. + """ + db_cluster_key = _build_db_cluster_key(params) + return SOURCE_DB_QUERY_MAP.get(db_cluster_key) + + +class RedashComplexPreviewClient(BaseRedashPreviewClient): + def __init__(self, + *, + redash_host: str = DEFAULT_URL, + user_api_key: Optional[str] = REDASH_USER_API_KEY) -> None: + super().__init__(redash_host=redash_host, user_api_key=user_api_key) + self.default_query_limit = 100 + self.max_redash_cache_age = 3600 # One Hour + + def _get_query_api_key(self, params: Dict) -> Optional[str]: + if params.get('database') in ['redshift']: + return os.environ.get('REDSHIFT_USER_API_KEY', '') + return None + + def get_redash_query_id(self, params: Dict) -> Optional[int]: + db_cluster_key = _build_db_cluster_key(params) + return SOURCE_DB_QUERY_MAP.get(db_cluster_key) + + def get_select_fields(self, params: Dict) -> str: + """ + Manually defining the dictionary in this function for readability + """ + # These are sample values to show how table-level select clauses work + field_select_vals = { + 'snowflake.ca_covid': { + 'open_data.case_demographics_age': ( + "date, SUBSTR(age_group, 0, 2) || '******' as age_group, totalpositive, case_percent, ca_percent" + ), + 'open_data.statewide_testing': 'date, tested' + } + } + + db_cluster_key = _build_db_cluster_key(params) + schema_tbl_key = f"{params.get('schema')}.{params.get('tableName')}" + + # Always returns a value, defaults to '*' if nothing is defined + return field_select_vals.get(db_cluster_key, {}).get(schema_tbl_key, '*') + + def get_where_clause(self, params: Dict) -> str: + """ + MUST return the entire where clause, including the word "where" + """ + where_vals = { + 'snowflake.ca_covid': { + 'open_data.case_demographics_age': "totalpositive < 120", + } + } + + db_cluster_key = _build_db_cluster_key(params) + schema_tbl_key = f"{params.get('schema')}.{params.get('tableName')}" + + # Always returns a value, defaults to an empty string ('') if nothing is defined + where_clause = where_vals.get(db_cluster_key, {}).get(schema_tbl_key, '') + + # Add the word where if a custom where clause is applied + if where_clause: + where_clause = f'WHERE {where_clause}' + + return where_clause diff --git a/frontend/amundsen_application/base/examples/example_s3_json_preview_client.py b/frontend/amundsen_application/base/examples/example_s3_json_preview_client.py new file mode 100644 index 0000000000..d6d544d1af --- /dev/null +++ b/frontend/amundsen_application/base/examples/example_s3_json_preview_client.py @@ -0,0 +1,73 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import json +import os +from typing import Dict + +import boto3 +from amundsen_application.base.base_s3_preview_client import \ + BaseS3PreviewClient +from amundsen_application.models.preview_data import ColumnItem, PreviewData + + +class S3JSONPreviewClient(BaseS3PreviewClient): + """ + S3JSONPreviewClient is an S3 Preview Client that: + 1. Gets JSON files from S3 that are stored in a bucket with keys preview_data/{schema}/{table}.json + 2. Converts the JSON values to PreviewData model + 3. 
Returns the serialized model
+
+    In order for this preview client to work you must:
+    - Have S3 files stored in a bucket with keys 'preview_data/{schema}/{table}.json'
+    - Format each file as a list of rows, where each row is a map keyed by column name with the column value as the value
+        Ex:
+        [
+            {
+                'col1': 1,
+                'col2': '2'
+            },
+            {
+                'col1': 3,
+                'col2': '4'
+            }
+            ...
+        ]
+    - Nested fields are not supported. We suggest flattening your nested fields.
+        Ex:
+        [
+            {
+                'col1': {
+                    'col2': 1
+                }
+            }
+        ]
+        should be:
+        [
+            {
+                'col1.col2': 1
+            }
+        ]
+    - Run your frontend service with an IAM Profile that has s3:GetObject permissions on the 'preview_data/' prefix
+    """
+
+    def __init__(self) -> None:
+        self.s3 = boto3.client("s3")
+        bucket = os.getenv("PREVIEW_CLIENT_S3_BUCKET")
+        if bucket == "":
+            raise Exception("When using the S3JSONPreviewClient you must set the PREVIEW_CLIENT_S3_BUCKET environment "
+                            "variable to point to where your preview_data JSON files are stored.")
+        self.s3_bucket = bucket
+
+    def get_s3_preview_data(self, *, params: Dict) -> PreviewData:
+        schema = params.get("schema")
+        table = params.get("tableName")
+
+        try:
+            obj = self.s3.get_object(Bucket=self.s3_bucket, Key=f"preview_data/{schema}/{table}.json")
+        except Exception as e:
+            raise Exception(f"Error getting object from s3: preview_data/{schema}/{table}.json. "
+                            f"Caused by: {e}")
+
+        data = json.loads(obj['Body'].read().decode('utf-8'))
+        columns = [ColumnItem(col_name, '') for col_name in data[0]]  # TODO: figure out how to do Type. Is it needed?
+        return PreviewData(columns=columns, data=data)
diff --git a/frontend/amundsen_application/base/examples/example_superset_preview_client.py b/frontend/amundsen_application/base/examples/example_superset_preview_client.py
new file mode 100644
index 0000000000..5c03a39b07
--- /dev/null
+++ b/frontend/amundsen_application/base/examples/example_superset_preview_client.py
@@ -0,0 +1,59 @@
+# Copyright Contributors to the Amundsen project.
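+# Call sketch (illustrative; this demo client hardcodes the 'main.ab_role' table below, so the
+# params only mirror what a real deployment would pass through):
+#
+#   client = SupersetPreviewClient()
+#   response = client.get_preview_data(params={'database': 'main', 'schema': 'main', 'tableName': 'ab_role'})
+#   # response is a Flask response whose JSON payload carries the rows under the 'preview_data' key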
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import requests +import uuid + +from requests import Response +from typing import Any, Dict # noqa: F401 + +from amundsen_application.base.base_superset_preview_client import BaseSupersetPreviewClient + +# 'main' is an existing default Superset database which serves for demo purposes +DEFAULT_DATABASE_MAP = { + 'main': 1, +} +DEFAULT_URL = 'http://localhost:8088/superset/sql_json/' + + +class SupersetPreviewClient(BaseSupersetPreviewClient): + def __init__(self, + *, + database_map: Dict[str, int] = DEFAULT_DATABASE_MAP, + url: str = DEFAULT_URL) -> None: + self.database_map = database_map + self.headers = {} + self.url = url + + def post_to_sql_json(self, *, params: Dict, headers: Dict) -> Response: + """ + Returns the post response from Superset's `sql_json` endpoint + """ + # Create the appropriate request data + try: + request_data = {} # type: Dict[str, Any] + + # Superset's sql_json endpoint requires a unique client_id + request_data['client_id'] = uuid.uuid4() + + # Superset's sql_json endpoint requires the id of the database that it will execute the query on + database_name = 'main' # OR params.get('database') in a real use case + request_data['database_id'] = self.database_map.get(database_name, '') + + # Generate the sql query for the desired data preview content + try: + # 'main' is an existing default Superset schema which serves for demo purposes + schema = 'main' # OR params.get('schema') in a real use case + + # 'ab_role' is an existing default Superset table which serves for demo purposes + table_name = 'ab_role' # OR params.get('tableName') in a real use case + + request_data['sql'] = 'SELECT * FROM {schema}.{table} LIMIT 50'.format(schema=schema, table=table_name) + except Exception as e: + logging.error('Encountered error generating request sql: ' + str(e)) + except Exception as e: + logging.error('Encountered error generating request data: ' + str(e)) + + # Post request to Superset's `sql_json` endpoint + return requests.post(self.url, data=request_data, headers=headers) diff --git a/frontend/amundsen_application/config.py b/frontend/amundsen_application/config.py new file mode 100644 index 0000000000..19d9872a28 --- /dev/null +++ b/frontend/amundsen_application/config.py @@ -0,0 +1,201 @@ +# Copyright Contributors to the Amundsen project. 
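+# Deployment sketch (hosts and flags below are hypothetical): installations usually subclass one
+# of the configs in this module (e.g. LocalConfig) and point the Flask app at that subclass rather
+# than editing this file directly.
+#
+#   class MyDeploymentConfig(LocalConfig):
+#       SEARCHSERVICE_BASE = 'http://search.internal:5001'
+#       METADATASERVICE_BASE = 'http://metadata.internal:5002'
+#       NOTIFICATIONS_ENABLED = True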
+# SPDX-License-Identifier: Apache-2.0 + +import os +import distutils.util +from typing import Callable, Dict, List, Optional, Set # noqa: F401 +from amundsen_application.models.user import User + +from flask import Flask # noqa: F401 + +from amundsen_application.tests.test_utils import get_test_user + + +class MatchRuleObject: + def __init__(self, + schema_regex: Optional[str] = None, + table_name_regex: Optional[str] = None, + ) -> None: + self.schema_regex = schema_regex + self.table_name_regex = table_name_regex + + +class Config: + LOG_FORMAT = '%(asctime)s.%(msecs)03d [%(levelname)s] %(module)s.%(funcName)s:%(lineno)d (%(process)d:' \ + + '%(threadName)s) - %(message)s' + LOG_DATE_FORMAT = '%Y-%m-%dT%H:%M:%S%z' + LOG_LEVEL = 'INFO' + + # Path to the logging configuration file to be used by `fileConfig()` method + # https://docs.python.org/3.7/library/logging.config.html#logging.config.fileConfig + # LOG_CONFIG_FILE = 'amundsen_application/logging.conf' + LOG_CONFIG_FILE = None + + COLUMN_STAT_ORDER = None # type: Dict[str, int] + + # The following three variables control whether table descriptions can be edited via the UI + # ALL_UNEDITABLE_SCHEMAS: set environment variable to 'true' if you don't want any schemas to be able to be edited + # UNEDITABLE_SCHEMAS: a set of schema names whose tables will not be editable + # UNEDITABLE_TABLE_DESCRIPTION_MATCH_RULES: a list of regex rules for schema name, table name, or both + # See https://www.amundsen.io/amundsen/frontend/docs/flask_config/#uneditable-table-descriptions for more info! + ALL_UNEDITABLE_SCHEMAS = os.getenv('ALL_UNEDITABLE_SCHEMAS', 'false') == 'true' # type: bool + UNEDITABLE_SCHEMAS = set() # type: Set[str] + UNEDITABLE_TABLE_DESCRIPTION_MATCH_RULES = [] # type: List[MatchRuleObject] + + # DEPRECATED (since version 3.9.0): Please use `POPULAR_RESOURCES_COUNT` + # Number of popular tables to be displayed on the index/search page + POPULAR_TABLE_COUNT = None + POPULAR_RESOURCES_COUNT = 4 # type: int + + # DEPRECATED (since version 3.9.0): Please use `POPULAR_RESOURCES_PERSONALIZATION` + # Personalize the popular tables response for the current authenticated user + POPULAR_TABLE_PERSONALIZATION = None + POPULAR_RESOURCES_PERSONALIZATION = False # type: bool + + # Request Timeout Configurations in Seconds + REQUEST_SESSION_TIMEOUT_SEC = 3 + + # Frontend Application + FRONTEND_BASE = '' + + # JS config override for frontend app + JS_CONFIG_OVERRIDE_ENABLED = False + + # Search Service + SEARCHSERVICE_REQUEST_CLIENT = None + SEARCHSERVICE_REQUEST_HEADERS = None + SEARCHSERVICE_BASE = '' + + # Metadata Service + METADATASERVICE_REQUEST_CLIENT = None + METADATASERVICE_REQUEST_HEADERS = None + METADATASERVICE_BASE = '' + + # Mail Client Features + MAIL_CLIENT = None + NOTIFICATIONS_ENABLED = False + + # Initialize custom routes + INIT_CUSTOM_ROUTES = None # type: Callable[[Flask], None] + + # Settings for Preview Client integration + PREVIEW_CLIENT_ENABLED = os.getenv('PREVIEW_CLIENT_ENABLED') == 'true' # type: bool + # Maps to a class path and name + PREVIEW_CLIENT = os.getenv('PREVIEW_CLIENT', None) # type: Optional[str] + PREVIEW_CLIENT_URL = os.getenv('PREVIEW_CLIENT_URL') # type: Optional[str] + PREVIEW_CLIENT_USERNAME = os.getenv('PREVIEW_CLIENT_USERNAME') # type: Optional[str] + PREVIEW_CLIENT_PASSWORD = os.getenv('PREVIEW_CLIENT_PASSWORD') # type: Optional[str] + PREVIEW_CLIENT_CERTIFICATE = os.getenv('PREVIEW_CLIENT_CERTIFICATE') # type: Optional[str] + + # Settings for Quality client + QUALITY_CLIENT = 
os.getenv('QUALITY_CLIENT', None) # type: Optional[str] + + # Settings for Announcement Client integration + ANNOUNCEMENT_CLIENT_ENABLED = os.getenv('ANNOUNCEMENT_CLIENT_ENABLED') == 'true' # type: bool + # Maps to a class path and name + ANNOUNCEMENT_CLIENT = os.getenv('ANNOUNCEMENT_CLIENT', None) # type: Optional[str] + + # Settings for resource Notice client + NOTICE_CLIENT = os.getenv('NOTICE_CLIENT', None) # type: Optional[str] + + # Settings for Issue tracker integration + ISSUE_LABELS = [] # type: List[str] + ISSUE_TRACKER_API_TOKEN = None # type: str + ISSUE_TRACKER_URL = None # type: str + ISSUE_TRACKER_USER = None # type: str + ISSUE_TRACKER_PASSWORD = None # type: str + ISSUE_TRACKER_PROJECT_ID = None # type: int + # Maps to a class path and name + ISSUE_TRACKER_CLIENT = None # type: str + ISSUE_TRACKER_CLIENT_ENABLED = False # type: bool + # Max issues to display at a time + ISSUE_TRACKER_MAX_RESULTS = None # type: int + # Override issue type ID for cloud Jira deployments + ISSUE_TRACKER_ISSUE_TYPE_ID = None + + # Programmatic Description configuration. Please see docs/flask_config.md + PROGRAMMATIC_DISPLAY = None # type: Optional[Dict] + + # If specified, will be used to generate headers for service-to-service communication + # Please note that if specified, this will ignore following config properties: + # 1. METADATASERVICE_REQUEST_HEADERS + # 2. SEARCHSERVICE_REQUEST_HEADERS + REQUEST_HEADERS_METHOD: Optional[Callable[[Flask], Optional[Dict]]] = None + + AUTH_USER_METHOD: Optional[Callable[[Flask], User]] = None + GET_PROFILE_URL = None + + # For additional preview client, register more at DefaultPreviewMethodFactory.__init__() + # For any private preview client, use custom factory that implements BasePreviewMethodFactory + DASHBOARD_PREVIEW_FACTORY = None # By default DefaultPreviewMethodFactory will be used. + DASHBOARD_PREVIEW_IMAGE_CACHE_MAX_AGE_SECONDS = 60 * 60 * 24 * 1 # 1 day + + CREDENTIALS_MODE_ADMIN_TOKEN = os.getenv('CREDENTIALS_MODE_ADMIN_TOKEN', None) + CREDENTIALS_MODE_ADMIN_PASSWORD = os.getenv('CREDENTIALS_MODE_ADMIN_PASSWORD', None) + MODE_ORGANIZATION = None + MODE_REPORT_URL_TEMPLATE = None + # Add Preview class name below to enable ACL, assuming it is supported by the Preview class + # e.g: ACL_ENABLED_DASHBOARD_PREVIEW = {'ModePreview'} + ACL_ENABLED_DASHBOARD_PREVIEW = set() # type: Set[Optional[str]] + + MTLS_CLIENT_CERT = os.getenv('MTLS_CLIENT_CERT') + """ + Optional. + The path to a PEM formatted certificate to present when calling the metadata and search services. + MTLS_CLIENT_KEY must also be set. + """ + + MTLS_CLIENT_KEY = os.getenv('MTLS_CLIENT_KEY') + """Optional. The path to a PEM formatted key to use with the MTLS_CLIENT_CERT. MTLS_CLIENT_CERT must also be set.""" + + +class LocalConfig(Config): + DEBUG = False + TESTING = False + LOG_LEVEL = 'DEBUG' + + FRONTEND_PORT = '5000' + # If installing locally directly from the github source + # modify these ports if necessary to point to you local search and metadata services + SEARCH_PORT = '5001' + METADATA_PORT = '5002' + + # If installing using the Docker bootstrap, this should be modified to the docker host ip. 
+ LOCAL_HOST = '0.0.0.0' + + JS_CONFIG_OVERRIDE_ENABLED = bool(distutils.util.strtobool(os.environ.get('JS_CONFIG_OVERRIDE_ENABLED', 'False'))) + + FRONTEND_BASE = os.environ.get('FRONTEND_BASE', + 'http://{LOCAL_HOST}:{PORT}'.format( + LOCAL_HOST=LOCAL_HOST, + PORT=FRONTEND_PORT) + ) + + SEARCHSERVICE_BASE = os.environ.get('SEARCHSERVICE_BASE', + 'http://{LOCAL_HOST}:{PORT}'.format( + LOCAL_HOST=LOCAL_HOST, + PORT=SEARCH_PORT) + ) + + METADATASERVICE_BASE = os.environ.get('METADATASERVICE_BASE', + 'http://{LOCAL_HOST}:{PORT}'.format( + LOCAL_HOST=LOCAL_HOST, + PORT=METADATA_PORT) + ) + + +class TestConfig(LocalConfig): + POPULAR_RESOURCES_PERSONALIZATION = True + AUTH_USER_METHOD = get_test_user + NOTIFICATIONS_ENABLED = True + ISSUE_TRACKER_URL = 'test_url' + ISSUE_TRACKER_USER = 'test_user' + ISSUE_TRACKER_PASSWORD = 'test_password' + ISSUE_TRACKER_PROJECT_ID = 1 + ISSUE_TRACKER_CLIENT_ENABLED = True + ISSUE_TRACKER_MAX_RESULTS = 3 + + +class TestNotificationsDisabledConfig(LocalConfig): + AUTH_USER_METHOD = get_test_user + NOTIFICATIONS_ENABLED = False diff --git a/frontend/amundsen_application/deprecations.py b/frontend/amundsen_application/deprecations.py new file mode 100644 index 0000000000..79a6d4a385 --- /dev/null +++ b/frontend/amundsen_application/deprecations.py @@ -0,0 +1,35 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import os +import warnings + +from flask import Flask + +warnings.simplefilter('always', DeprecationWarning) + + +# Deprecation Warnings +def process_deprecations(app: Flask) -> None: + if os.getenv('APP_WRAPPER') or os.getenv('APP_WRAPPER_CLASS'): + warnings.warn("'APP_WRAPPER' and 'APP_WRAPPER_CLASS' variables are deprecated since version (3.9.0), " + "and will be removed in version 4. " + "Please use 'FLASK_APP_MODULE_NAME' and 'FLASK_APP_CLASS_NAME' instead", + DeprecationWarning) + + if os.getenv('APP_WRAPPER_ARGS'): + warnings.warn("'APP_WRAPPER_ARGS' variable is deprecated since version (3.9.0), " + "and will be removed in version 4. " + "Please use 'FLASK_APP_KWARGS_DICT' instead", DeprecationWarning) + + if app.config.get("POPULAR_TABLE_COUNT", None) is not None: + app.config["POPULAR_RESOURCES_COUNT"] = app.config["POPULAR_TABLE_COUNT"] + warnings.warn("'POPULAR_TABLE_COUNT' variable is deprecated since version (3.9.0), " + "and will be removed in version 4. " + "Please use 'POPULAR_RESOURCES_COUNT' instead", DeprecationWarning) + + if app.config.get("POPULAR_TABLE_PERSONALIZATION", None) is not None: + app.config["POPULAR_RESOURCES_PERSONALIZATION"] = app.config["POPULAR_TABLE_PERSONALIZATION"] + warnings.warn("'POPULAR_TABLE_PERSONALIZATION' variable is deprecated since version (3.9.0), " + "and will be removed in version 4. " + "Please use 'POPULAR_RESOURCES_PERSONALIZATION' instead", DeprecationWarning) diff --git a/frontend/amundsen_application/log/__init__.py b/frontend/amundsen_application/log/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/frontend/amundsen_application/log/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/frontend/amundsen_application/log/action_log.py b/frontend/amundsen_application/log/action_log.py new file mode 100644 index 0000000000..d81eb8da85 --- /dev/null +++ b/frontend/amundsen_application/log/action_log.py @@ -0,0 +1,89 @@ +# Copyright Contributors to the Amundsen project. 
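+# Usage sketch (the endpoint function is hypothetical): wrapping a function with the decorator
+# defined below emits ActionLogParams to every registered pre- and post-execution callback.
+#
+#   @action_logging
+#   def log_table_view(*, table_key: str, index: int, source: str) -> None:
+#       pass  # command, user, host, arguments and timing are captured by the decorator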
+# SPDX-License-Identifier: Apache-2.0 + +import functools +import getpass + +import json +import logging +import socket +from datetime import datetime, timezone, timedelta + +from typing import Any, Dict, Callable +from flask import current_app as flask_app +from amundsen_application.log import action_log_callback +from amundsen_application.log.action_log_model import ActionLogParams + +LOGGER = logging.getLogger(__name__) +EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc) # use POSIX epoch + + +def action_logging(f: Callable) -> Any: + """ + Decorates function to execute function at the same time triggering action logger callbacks. + It will call action logger callbacks twice, one for pre-execution and the other one for post-execution. + Action logger will be called with ActionLogParams + + :param f: function instance + :return: wrapped function + """ + @functools.wraps(f) + def wrapper(*args: Any, + **kwargs: Any) -> Any: + """ + An wrapper for api functions. It creates ActionLogParams based on the function name, positional arguments, + and keyword arguments. + + :param args: A passthrough positional arguments. + :param kwargs: A passthrough keyword argument + """ + metrics = _build_metrics(f.__name__, *args, **kwargs) + action_log_callback.on_pre_execution(ActionLogParams(**metrics)) + output = None + try: + output = f(*args, **kwargs) + return output + except Exception as e: + metrics['error'] = e + raise + finally: + metrics['end_epoch_ms'] = get_epoch_millisec() + try: + metrics['output'] = json.dumps(output) + except Exception: + metrics['output'] = output + + action_log_callback.on_post_execution(ActionLogParams(**metrics)) + + return wrapper + + +def get_epoch_millisec() -> int: + return (datetime.now(timezone.utc) - EPOCH) // timedelta(milliseconds=1) + + +def _build_metrics(func_name: str, + *args: Any, + **kwargs: Any) -> Dict[str, Any]: + """ + Builds metrics dict from function args + :param func_name: + :param args: + :param kwargs: + :return: Dict that matches ActionLogParams variable + """ + + metrics = { + 'command': kwargs.get('command', func_name), + 'start_epoch_ms': get_epoch_millisec(), + 'host_name': socket.gethostname(), + 'pos_args_json': json.dumps(args), + 'keyword_args_json': json.dumps(kwargs), + } # type: Dict[str, Any] + + if flask_app.config['AUTH_USER_METHOD']: + metrics['user'] = flask_app.config['AUTH_USER_METHOD'](flask_app).email + else: + metrics['user'] = getpass.getuser() + + return metrics diff --git a/frontend/amundsen_application/log/action_log_callback.py b/frontend/amundsen_application/log/action_log_callback.py new file mode 100644 index 0000000000..d7d6ba01a1 --- /dev/null +++ b/frontend/amundsen_application/log/action_log_callback.py @@ -0,0 +1,104 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +""" +An Action Logger module. Singleton pattern has been applied into this module +so that registered callbacks can be used all through the same python process. +""" + +import logging +import sys +from typing import Callable, List # noqa: F401 + +from pkg_resources import iter_entry_points + +from amundsen_application.log.action_log_model import ActionLogParams + +LOGGER = logging.getLogger(__name__) + +__pre_exec_callbacks = [] # type: List[Callable] +__post_exec_callbacks = [] # type: List[Callable] + + +def register_pre_exec_callback(action_log_callback: Callable) -> None: + """ + Registers more action_logger function callback for pre-execution. 
This function callback is expected to be called + with keyword args. For more about the arguments that is being passed to the callback, refer to + amundsen_application.log.action_log_model.ActionLogParams + :param action_logger: An action logger callback function + :return: None + """ + LOGGER.debug("Adding {} to pre execution callback".format(action_log_callback)) + __pre_exec_callbacks.append(action_log_callback) + + +def register_post_exec_callback(action_log_callback: Callable) -> None: + """ + Registers more action_logger function callback for post-execution. This function callback is expected to be + called with keyword args. For more about the arguments that is being passed to the callback, + amundsen_application.log.action_log_model.ActionLogParams + :param action_logger: An action logger callback function + :return: None + """ + LOGGER.debug("Adding {} to post execution callback".format(action_log_callback)) + __post_exec_callbacks.append(action_log_callback) + + +def on_pre_execution(action_log_params: ActionLogParams) -> None: + """ + Calls callbacks before execution. + Note that any exception from callback will be logged but won't be propagated. + :param kwargs: + :return: None + """ + LOGGER.debug("Calling callbacks: {}".format(__pre_exec_callbacks)) + for call_back_function in __pre_exec_callbacks: + try: + call_back_function(action_log_params) + except Exception: + logging.exception('Failed on pre-execution callback using {}'.format(call_back_function)) + + +def on_post_execution(action_log_params: ActionLogParams) -> None: + """ + Calls callbacks after execution. As it's being called after execution, it can capture most of fields in + amundsen_application.log.action_log_model.ActionLogParams. Note that any exception from callback will be logged + but won't be propagated. + :param kwargs: + :return: None + """ + LOGGER.debug("Calling callbacks: {}".format(__post_exec_callbacks)) + for call_back_function in __post_exec_callbacks: + try: + call_back_function(action_log_params) + except Exception: + logging.exception('Failed on post-execution callback using {}'.format(call_back_function)) + + +def logging_action_log(action_log_params: ActionLogParams) -> None: + """ + An action logger callback that just logs the ActionLogParams that it receives. + :param **kwargs keyword arguments + :return: None + """ + if LOGGER.isEnabledFor(logging.DEBUG): + LOGGER.debug('logging_action_log: {}'.format(action_log_params)) + + +def register_action_logs() -> None: + """ + Retrieve declared action log callbacks from entry point where there are two groups that can be registered: + 1. "action_log.post_exec.plugin": callback for pre-execution + 2. 
"action_log.pre_exec.plugin": callback for post-execution + :return: None + """ + for entry_point in iter_entry_points(group='action_log.post_exec.plugin', name=None): + print('Registering post_exec action_log entry_point: {}'.format(entry_point), file=sys.stderr) + register_post_exec_callback(entry_point.load()) + + for entry_point in iter_entry_points(group='action_log.pre_exec.plugin', name=None): + print('Registering pre_exec action_log entry_point: {}'.format(entry_point), file=sys.stderr) + register_pre_exec_callback(entry_point.load()) + + +register_action_logs() diff --git a/frontend/amundsen_application/log/action_log_model.py b/frontend/amundsen_application/log/action_log_model.py new file mode 100644 index 0000000000..c63c367adc --- /dev/null +++ b/frontend/amundsen_application/log/action_log_model.py @@ -0,0 +1,43 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional + + +class ActionLogParams(object): + """ + Holds parameters for Action log + """ + + def __init__(self, *, + command: str, + start_epoch_ms: int, + end_epoch_ms: Optional[int] = None, + user: str, + host_name: str, + pos_args_json: str, + keyword_args_json: str, + output: Any = None, + error: Optional[Exception] = None) -> None: + self.command = command + self.start_epoch_ms = start_epoch_ms + self.end_epoch_ms = end_epoch_ms + self.user = user + self.host_name = host_name + self.pos_args_json = pos_args_json + self.keyword_args_json = keyword_args_json + self.output = output + self.error = error + + def __repr__(self) -> str: + return 'ActionLogParams(command={!r}, start_epoch_ms={!r}, end_epoch_ms={!r}, user={!r}, ' \ + 'host_name={!r}, pos_args_json={!r}, keyword_args_json={!r}, output={!r}, error={!r})'\ + .format(self.command, + self.start_epoch_ms, + self.end_epoch_ms, + self.user, + self.host_name, + self.pos_args_json, + self.keyword_args_json, + self.output, + self.error) diff --git a/frontend/amundsen_application/models/__init__.py b/frontend/amundsen_application/models/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/frontend/amundsen_application/models/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/frontend/amundsen_application/models/announcements.py b/frontend/amundsen_application/models/announcements.py new file mode 100644 index 0000000000..75b4d78d52 --- /dev/null +++ b/frontend/amundsen_application/models/announcements.py @@ -0,0 +1,39 @@ +# Copyright Contributors to the Amundsen project. 
+# SPDX-License-Identifier: Apache-2.0 + +from marshmallow import Schema, fields, post_dump +from marshmallow.exceptions import ValidationError + +from typing import Dict, List, Any + + +class Post: + def __init__(self, date: str, title: str, html_content: str) -> None: + self.date = date + self.html_content = html_content + self.title = title + + +class PostSchema(Schema): + date = fields.Str(required=True) + title = fields.Str(required=True) + html_content = fields.Str(required=True) + + +class Announcements: + def __init__(self, posts: List = []) -> None: + self.posts = posts + + +class AnnouncementsSchema(Schema): + posts = fields.Nested(PostSchema, many=True) + + @post_dump + def validate_data(self, data: Dict, **kwargs: Any) -> Dict: + posts = data.get('posts', []) + for post in posts: + if post.get('date') is None: + raise ValidationError('All posts must have a date') + if post.get('title') is None: + raise ValidationError('All posts must have a title') + return data diff --git a/frontend/amundsen_application/models/data_issue.py b/frontend/amundsen_application/models/data_issue.py new file mode 100644 index 0000000000..d02c205eaa --- /dev/null +++ b/frontend/amundsen_application/models/data_issue.py @@ -0,0 +1,59 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from enum import Enum +from typing import Optional + + +class Priority(Enum): + P0 = ('P0', 'Blocker') + P1 = ('P1', 'Critical') + P2 = ('P2', 'Major') + P3 = ('P3', 'Minor') + + def __init__(self, level: str, jira_severity: str): + self.level = level + self.jira_severity = jira_severity + + # JIRA SDK does not return priority beyond the name + @staticmethod + def from_jira_severity(jira_severity: str) -> 'Optional[Priority]': + jira_severity_to_priority = { + p.jira_severity: p for p in Priority + } + + return jira_severity_to_priority.get(jira_severity) + + @staticmethod + def from_level(level: str) -> 'Optional[Priority]': + level_to_priority = { + p.level: p for p in Priority + } + + return level_to_priority.get(level) + + @staticmethod + def get_jira_severity_from_level(level: str) -> str: + return Priority[level].jira_severity + + +class DataIssue: + def __init__(self, + issue_key: str, + title: str, + url: str, + status: str, + priority: Optional[Priority]) -> None: + self.issue_key = issue_key + self.title = title + self.url = url + self.status = status + self.priority = priority + + def serialize(self) -> dict: + return {'issue_key': self.issue_key, + 'title': self.title, + 'url': self.url, + 'status': self.status, + 'priority_name': self.priority.jira_severity.lower() if self.priority else None, + 'priority_display_name': self.priority.level if self.priority else None} diff --git a/frontend/amundsen_application/models/issue_results.py b/frontend/amundsen_application/models/issue_results.py new file mode 100644 index 0000000000..f2463133f8 --- /dev/null +++ b/frontend/amundsen_application/models/issue_results.py @@ -0,0 +1,38 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from amundsen_application.models.data_issue import DataIssue +from typing import Dict, List, Optional + + +class IssueResults: + def __init__(self, + issues: List[DataIssue], + total: int, + all_issues_url: str, + open_issues_url: Optional[str] = '', + closed_issues_url: Optional[str] = '', + open_count: Optional[int] = 0) -> None: + """ + Returns an object representing results from an issue tracker. 
+ :param issues: Issues in the issue tracker matching the requested table + :param total: How many issues in all are associated with this table + :param all_issues_url: url to the all issues in the issue tracker + :param open_issues_url: url to the open issues in the issue tracker + :param closed_issues_url: url to the closed issues in the issue tracker + :param open_count: How many open issues are associated with this table + """ + self.issues = issues + self.total = total + self.all_issues_url = all_issues_url + self.open_issues_url = open_issues_url + self.closed_issues_url = closed_issues_url + self.open_count = open_count + + def serialize(self) -> Dict: + return {'issues': [issue.serialize() for issue in self.issues], + 'total': self.total, + 'all_issues_url': self.all_issues_url, + 'open_issues_url': self.open_issues_url, + 'closed_issues_url': self.closed_issues_url, + 'open_count': self.open_count} diff --git a/frontend/amundsen_application/models/notice.py b/frontend/amundsen_application/models/notice.py new file mode 100644 index 0000000000..a945ed6d8e --- /dev/null +++ b/frontend/amundsen_application/models/notice.py @@ -0,0 +1,14 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass + + +@dataclass +class ResourceNotice: + """ + An object representing a notice to be displayed about a particular data resource (e.g. table or dashboard). + """ + severity: int + message: str + details: dict diff --git a/frontend/amundsen_application/models/preview_data.py b/frontend/amundsen_application/models/preview_data.py new file mode 100644 index 0000000000..d7ed7d8f60 --- /dev/null +++ b/frontend/amundsen_application/models/preview_data.py @@ -0,0 +1,29 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from marshmallow import Schema, fields, EXCLUDE +from typing import List, Optional + + +class ColumnItem: + def __init__(self, column_name: Optional[str] = None, column_type: Optional[str] = None) -> None: + self.column_name = column_name + self.column_type = column_type + + +class ColumnItemSchema(Schema): + column_name = fields.Str() + column_type = fields.Str() + + +class PreviewData: + def __init__(self, columns: List = [], data: List = [], error_text: str = '') -> None: + self.columns = columns + self.data = data + self.error_text = error_text + + +class PreviewDataSchema(Schema): + columns = fields.Nested(ColumnItemSchema, many=True, unknown=EXCLUDE) + data = fields.List(fields.Dict, many=True) + error_text = fields.Str() diff --git a/frontend/amundsen_application/models/quality.py b/frontend/amundsen_application/models/quality.py new file mode 100644 index 0000000000..005df5fd2a --- /dev/null +++ b/frontend/amundsen_application/models/quality.py @@ -0,0 +1,22 @@ +# Copyright Contributors to the Amundsen project. 
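+# Construction sketch (the numbers and URL are hypothetical): an instance of the attrs model
+# below, serialized via TableQualityChecksSchema.
+#
+#   TableQualityChecksSummary(num_checks_success=8, num_checks_failed=1, num_checks_total=9,
+#                             external_url='https://dq.example.com/checks/example_table',
+#                             last_run_timestamp=1618033988)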
+# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import attr +from marshmallow3_annotations.ext.attrs import AttrsSchema + + +@attr.s(auto_attribs=True, kw_only=True) +class TableQualityChecksSummary: + num_checks_success: int = attr.ib() + num_checks_failed: int = attr.ib() + num_checks_total: int = attr.ib() + external_url: str = attr.ib() + last_run_timestamp: Optional[int] = attr.ib() + + +class TableQualityChecksSchema(AttrsSchema): + class Meta: + target = TableQualityChecksSummary + register_as_schema = True diff --git a/frontend/amundsen_application/models/user.py b/frontend/amundsen_application/models/user.py new file mode 100644 index 0000000000..f64a17c2a5 --- /dev/null +++ b/frontend/amundsen_application/models/user.py @@ -0,0 +1,38 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, Optional + +from amundsen_common.models.user import UserSchema, User +from flask import current_app as app +from marshmallow import ValidationError + + +def _str_no_value(s: Optional[str]) -> bool: + # Returns True if the given string is None or empty + if not s: + return True + if len(s.strip()) == 0: + return True + return False + + +def load_user(user_data: Dict) -> User: + try: + schema = UserSchema() + # In order to call 'GET_PROFILE_URL' we make sure the user id exists + if _str_no_value(user_data.get('user_id')): + user_data['user_id'] = user_data.get('email') + # Add profile_url from optional 'GET_PROFILE_URL' configuration method. + # This methods currently exists for the case where the 'profile_url' is not included + # in the user metadata. + if _str_no_value(user_data.get('profile_url')) and app.config['GET_PROFILE_URL']: + user_data['profile_url'] = app.config['GET_PROFILE_URL'](user_data['user_id']) + return schema.load(user_data) + except ValidationError as err: + return err.messages + + +def dump_user(user: User) -> Dict: + schema = UserSchema() + return schema.dump(user) diff --git a/frontend/amundsen_application/oidc_config.py b/frontend/amundsen_application/oidc_config.py new file mode 100644 index 0000000000..414480c8c3 --- /dev/null +++ b/frontend/amundsen_application/oidc_config.py @@ -0,0 +1,42 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 +import json + +from typing import Dict, Optional +from flask import Flask, session +from amundsen_application.config import LocalConfig +from amundsen_application.models.user import load_user, User + + +def get_access_headers(app: Flask) -> Optional[Dict]: + """ + Function to retrieve and format the Authorization Headers + that can be passed to various microservices who are expecting that. + :param oidc: OIDC object having authorization information + :return: A formatted dictionary containing access token + as Authorization header. + """ + try: + # noinspection PyUnresolvedReferences + access_token = json.dumps(app.auth_client.token) + return {'Authorization': 'Bearer {}'.format(access_token)} + except Exception: + return {} + + +def get_auth_user(app: Flask) -> User: + """ + Retrieves the user information from oidc token, and then makes + a dictionary 'UserInfo' from the token information dictionary. + We need to convert it to a class in order to use the information + in the rest of the Amundsen application. + :param app: The instance of the current app. 
+ :return: A class UserInfo (Note, there isn't a UserInfo class, so we use Any) + """ + user_info = load_user(session.get("user")) + return user_info + + +class OidcConfig(LocalConfig): + AUTH_USER_METHOD = get_auth_user + REQUEST_HEADERS_METHOD = get_access_headers diff --git a/frontend/amundsen_application/proxy/__init__.py b/frontend/amundsen_application/proxy/__init__.py new file mode 100644 index 0000000000..f3145d75b3 --- /dev/null +++ b/frontend/amundsen_application/proxy/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/frontend/amundsen_application/proxy/issue_tracker_clients/__init__.py b/frontend/amundsen_application/proxy/issue_tracker_clients/__init__.py new file mode 100644 index 0000000000..28e224389b --- /dev/null +++ b/frontend/amundsen_application/proxy/issue_tracker_clients/__init__.py @@ -0,0 +1,46 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from flask import current_app as app +from threading import Lock +from werkzeug.utils import import_string + +from amundsen_application.base.base_issue_tracker_client import BaseIssueTrackerClient + +_issue_tracker_client = None +_issue_tracker_client_lock = Lock() + + +def get_issue_tracker_client() -> BaseIssueTrackerClient: + """ + Provides singleton proxy client based on the config + :return: Proxy instance of any subclass of BaseProxy + """ + global _issue_tracker_client + + if _issue_tracker_client: + return _issue_tracker_client + + with _issue_tracker_client_lock: + if _issue_tracker_client: + return _issue_tracker_client + else: + # Gather all the configuration to create an IssueTrackerClient + if app.config['ISSUE_TRACKER_CLIENT_ENABLED']: + url = app.config['ISSUE_TRACKER_URL'] + user = app.config['ISSUE_TRACKER_USER'] + password = app.config['ISSUE_TRACKER_PASSWORD'] + project_id = app.config['ISSUE_TRACKER_PROJECT_ID'] + max_results = app.config['ISSUE_TRACKER_MAX_RESULTS'] + issue_labels = app.config['ISSUE_LABELS'] + + if app.config['ISSUE_TRACKER_CLIENT']: + client = import_string(app.config['ISSUE_TRACKER_CLIENT']) + _issue_tracker_client = client(issue_labels=issue_labels, + issue_tracker_url=url, + issue_tracker_user=user, + issue_tracker_password=password, + issue_tracker_project_id=project_id, + issue_tracker_max_results=max_results) + + return _issue_tracker_client diff --git a/frontend/amundsen_application/proxy/issue_tracker_clients/asana_client.py b/frontend/amundsen_application/proxy/issue_tracker_clients/asana_client.py new file mode 100644 index 0000000000..663e382057 --- /dev/null +++ b/frontend/amundsen_application/proxy/issue_tracker_clients/asana_client.py @@ -0,0 +1,194 @@ +# Copyright Contributors to the Amundsen project. 
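+# Configuration sketch (the URL, token and project id are hypothetical): get_issue_tracker_client()
+# in the package __init__ above builds this client from the Flask config, roughly:
+#
+#   ISSUE_TRACKER_CLIENT_ENABLED = True
+#   ISSUE_TRACKER_CLIENT = 'amundsen_application.proxy.issue_tracker_clients.asana_client.AsanaClient'
+#   ISSUE_TRACKER_URL = 'https://app.asana.com'
+#   ISSUE_TRACKER_PASSWORD = '<asana personal access token>'  # passed to asana.Client.access_token() below
+#   ISSUE_TRACKER_PROJECT_ID = 1200000000000000               # treated as the Asana project gid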
+# SPDX-License-Identifier: Apache-2.0 + +import asana +import logging +from typing import Any, Dict, List + +from amundsen_application.base.base_issue_tracker_client import BaseIssueTrackerClient +from amundsen_application.models.data_issue import DataIssue, Priority +from amundsen_application.models.issue_results import IssueResults + + +class AsanaClient(BaseIssueTrackerClient): + + def __init__(self, issue_labels: List[str], + issue_tracker_url: str, + issue_tracker_user: str, + issue_tracker_password: str, + issue_tracker_project_id: int, + issue_tracker_max_results: int) -> None: + self.issue_labels = issue_labels + self.asana_url = issue_tracker_url + self.asana_user = issue_tracker_user + self.asana_password = issue_tracker_password + self.asana_max_results = issue_tracker_max_results + + self.asana_project_gid = issue_tracker_project_id + self.asana_client = asana.Client.access_token(issue_tracker_password) + + asana_project = self.asana_client.projects.get_project(self.asana_project_gid) + self.asana_workspace_gid = asana_project['workspace']['gid'] + + self._setup_custom_fields() + + def get_issues(self, table_uri: str) -> IssueResults: + """ + :param table_uri: Table Uri ie databasetype://database/table + :return: Metadata of matching issues + """ + + table_parent_task_gid = self._get_parent_task_gid_for_table_uri(table_uri) + + tasks = list(self.asana_client.tasks.get_subtasks_for_task( + table_parent_task_gid, + { + 'opt_fields': [ + 'name', 'completed', 'notes', 'custom_fields', + ] + } + )) + + return IssueResults( + issues=[ + self._asana_task_to_amundsen_data_issue(task) for task in tasks + ], + total=len(tasks), + all_issues_url=self._task_url(table_parent_task_gid) + ) + + def create_issue(self, + table_uri: str, + title: str, + description: str, + priority_level: str, + table_url: str, + **kwargs: Any) -> DataIssue: + """ + Creates an issue in Asana + :param description: Description of the Asana issue + :param priority_level: Priority level for the ticket + :param table_uri: Table Uri ie databasetype://database/table + :param title: Title of the Asana ticket + :param table_url: Link to access the table + :return: Metadata about the newly created issue + """ + + table_parent_task_gid = self._get_parent_task_gid_for_table_uri(table_uri) + enum_value = next(opt for opt in self.priority_field_enum_options if opt['name'] == priority_level) + + return self._asana_task_to_amundsen_data_issue( + self.asana_client.tasks.create_subtask_for_task( + table_parent_task_gid, + { + 'name': title, + 'notes': description + f'\n Table URL: {table_url}', + 'custom_fields': {self.priority_field_gid: enum_value['gid']} + } + ) + ) + + def _setup_custom_fields(self) -> None: + TABLE_URI_FIELD_NAME = 'Table URI (Amundsen)' + PRIORITY_FIELD_NAME = 'Priority (Amundsen)' + + custom_fields = \ + self.asana_client.custom_field_settings.get_custom_field_settings_for_project( + self.asana_project_gid + ) + + custom_fields = {f['custom_field']['name']: f['custom_field'] for f in custom_fields} + + if TABLE_URI_FIELD_NAME in custom_fields: + table_uri_field = custom_fields[TABLE_URI_FIELD_NAME] + else: + table_uri_field = self.asana_client.custom_fields.create_custom_field({ + 'workspace': self.asana_workspace_gid, + 'name': TABLE_URI_FIELD_NAME, + 'format': 'custom', + 'resource_subtype': 'text', + }) + + self.asana_client.projects.add_custom_field_setting_for_project( + self.asana_project_gid, + { + 'custom_field': table_uri_field['gid'], + 'is_important': True, + } + ) + + if PRIORITY_FIELD_NAME in 
custom_fields:
+            priority_field = custom_fields[PRIORITY_FIELD_NAME]
+        else:
+            priority_field = self.asana_client.custom_fields.create_custom_field({
+                'workspace': self.asana_workspace_gid,
+                'name': PRIORITY_FIELD_NAME,
+                'format': 'custom',
+                'resource_subtype': 'enum',
+                'enum_options': [
+                    {
+                        'name': p.level
+                    } for p in Priority
+                ]
+            })
+
+            self.asana_client.projects.add_custom_field_setting_for_project(
+                self.asana_project_gid,
+                {
+                    'custom_field': priority_field['gid'],
+                    'is_important': True,
+                }
+            )
+
+        self.table_uri_field_gid = table_uri_field['gid']
+        self.priority_field_gid = priority_field['gid']
+        self.priority_field_enum_options = priority_field['enum_options']
+
+    def _get_parent_task_gid_for_table_uri(self, table_uri: str) -> str:
+        table_parent_tasks = list(self.asana_client.tasks.search_tasks_for_workspace(
+            self.asana_workspace_gid,
+            {
+                'projects.any': [self.asana_project_gid],
+                'custom_fields.{}.value'.format(self.table_uri_field_gid): table_uri,
+            }
+        ))
+
+        # Create the parent task if it doesn't exist.
+        if len(table_parent_tasks) == 0:
+            table_parent_task = self.asana_client.tasks.create_task({
+                'name': table_uri,
+                'custom_fields': {
+                    self.table_uri_field_gid: table_uri,
+                },
+                'projects': [self.asana_project_gid],
+            })
+
+            return table_parent_task['gid']
+        else:
+            if len(table_parent_tasks) > 1:
+                logging.warning('There is more than one task with the name "{}"'.format(table_uri))
+
+            return table_parent_tasks[0]['gid']
+
+    def _task_url(self, task_gid: str) -> str:
+        return 'https://app.asana.com/0/{project_gid}/{task_gid}'.format(
+            project_gid=self.asana_project_gid, task_gid=task_gid
+        )
+
+    def _asana_task_to_amundsen_data_issue(self, task: Dict) -> DataIssue:
+        custom_fields = {f['gid']: f for f in task['custom_fields']}
+        priority_field = custom_fields[self.priority_field_gid]
+
+        priority = None
+        if priority_field.get('enum_value'):
+            priority = Priority.from_level(priority_field['enum_value']['name'])
+        else:
+            priority = Priority.P3
+
+        return DataIssue(
+            issue_key=task['gid'],
+            title=task['name'],
+            url=self._task_url(task['gid']),
+            status='closed' if task['completed'] else 'open',
+            priority=priority,
+        )
diff --git a/frontend/amundsen_application/proxy/issue_tracker_clients/issue_exceptions.py b/frontend/amundsen_application/proxy/issue_tracker_clients/issue_exceptions.py
new file mode 100644
index 0000000000..416a4c27a7
--- /dev/null
+++ b/frontend/amundsen_application/proxy/issue_tracker_clients/issue_exceptions.py
@@ -0,0 +1,9 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+
+class IssueConfigurationException(Exception):
+    """
+    Raised when there are missing configuration settings
+    """
+    pass
diff --git a/frontend/amundsen_application/proxy/issue_tracker_clients/jira_client.py b/frontend/amundsen_application/proxy/issue_tracker_clients/jira_client.py
new file mode 100644
index 0000000000..a42db9dcaa
--- /dev/null
+++ b/frontend/amundsen_application/proxy/issue_tracker_clients/jira_client.py
@@ -0,0 +1,355 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+from http import HTTPStatus
+from jira import JIRA, JIRAError, Issue, User as JiraUser
+from typing import Any, List
+
+from flask import current_app as app
+
+from amundsen_application.api.metadata.v0 import USER_ENDPOINT
+from amundsen_application.api.utils.request_utils import request_metadata
+from amundsen_application.base.base_issue_tracker_client import BaseIssueTrackerClient
+from amundsen_application.proxy.issue_tracker_clients.issue_exceptions import IssueConfigurationException
+from amundsen_application.models.data_issue import DataIssue, Priority
+from amundsen_application.models.issue_results import IssueResults
+from amundsen_application.models.user import load_user
+from amundsen_common.models.user import User
+
+import urllib.parse
+import logging
+
+SEARCH_STUB_ALL_ISSUES = ('text ~ "\\"Table Key: {table_key} [PLEASE DO NOT REMOVE]\\"" '
+                          'and (resolution = unresolved or (resolution != unresolved and updated > -30d)) '
+                          'order by resolution DESC, priority DESC, createdDate DESC')
+SEARCH_STUB_OPEN_ISSUES = ('text ~ "\\"Table Key: {table_key} [PLEASE DO NOT REMOVE]\\"" '
+                           'and resolution = unresolved '
+                           'order by priority DESC, createdDate DESC')
+SEARCH_STUB_CLOSED_ISSUES = ('text ~ "\\"Table Key: {table_key} [PLEASE DO NOT REMOVE]\\"" '
+                             'and resolution != unresolved '
+                             'order by priority DESC, createdDate DESC')
+# This is the issue type id Jira assigns to a bug
+ISSUE_TYPE_ID = 1
+ISSUE_TYPE_NAME = 'Bug'
+
+
+class JiraClient(BaseIssueTrackerClient):
+
+    def __init__(self, issue_labels: List[str],
+                 issue_tracker_url: str,
+                 issue_tracker_user: str,
+                 issue_tracker_password: str,
+                 issue_tracker_project_id: int,
+                 issue_tracker_max_results: int) -> None:
+        self.issue_labels = issue_labels
+        self.jira_url = issue_tracker_url
+        self.jira_user = issue_tracker_user
+        self.jira_password = issue_tracker_password
+        self.jira_project_id = issue_tracker_project_id
+        self.jira_max_results = issue_tracker_max_results
+        self._validate_jira_configuration()
+        self.jira_client = self.get_client()
+
+    def get_client(self) -> JIRA:
+        """
+        Get a Jira client authenticated against the configured Jira server
+        :return: A Jira client.
+        """
+        return JIRA(
+            server=self.jira_url,
+            basic_auth=(self.jira_user, self.jira_password)
+        )
+
+    def get_issues(self, table_uri: str) -> IssueResults:
+        """
+        Runs a query against a given Jira project for tickets matching the key.
+        Returns open issues sorted by most recently created.
+        :param table_uri: Table Uri ie databasetype://database/table
+        :return: Metadata of matching issues
+        """
+        try:
+            issues = self.jira_client.search_issues(SEARCH_STUB_ALL_ISSUES.format(
+                table_key=table_uri),
+                maxResults=self.jira_max_results)
+
+            # Call search_issues for only 1 open/closed issue just to get the total values from the response.
The + # total count from all issues may not be accurate if older closed issues are excluded from the response + open_issues = self.jira_client.search_issues(SEARCH_STUB_OPEN_ISSUES.format( + table_key=table_uri), + maxResults=1) + closed_issues = self.jira_client.search_issues(SEARCH_STUB_CLOSED_ISSUES.format( + table_key=table_uri), + maxResults=1) + + returned_issues = self._sort_issues(issues) + return IssueResults(issues=returned_issues, + total=open_issues.total + closed_issues.total, + all_issues_url=self._generate_issues_url(SEARCH_STUB_ALL_ISSUES, + table_uri, + open_issues.total + closed_issues.total), + open_issues_url=self._generate_issues_url(SEARCH_STUB_OPEN_ISSUES, + table_uri, + open_issues.total), + closed_issues_url=self._generate_issues_url(SEARCH_STUB_CLOSED_ISSUES, + table_uri, + closed_issues.total), + open_count=open_issues.total) + except JIRAError as e: + logging.exception(str(e)) + raise e + + def create_issue(self, + table_uri: str, + title: str, + description: str, + priority_level: str, + table_url: str, + **kwargs: Any) -> DataIssue: + """ + Creates an issue in Jira + :param description: Description of the Jira issue + :param priority_level: Priority level for the ticket + :param table_uri: Table Uri ie databasetype://database/table + :param title: Title of the Jira ticket + :param table_url: Link to access the table + :param owner_ids: List of table owners user ids + :param frequent_user_ids: List of table frequent users user ids + :param project_key: Jira project key to specify where the ticket should be created + :return: Metadata about the newly created issue + """ + try: + if app.config['AUTH_USER_METHOD']: + user_email = app.config['AUTH_USER_METHOD'](app).email + # We currently cannot use the email directly because of the following issue: + # https://community.atlassian.com/t5/Answers-Developer-Questions/JIRA-Rest-API-find-JIRA-user-based-on-user-s-email-address/qaq-p/532715 + jira_id = user_email.split('@')[0] + else: + raise Exception('AUTH_USER_METHOD must be configured to set the JIRA issue reporter') + + reporter = {'name': jira_id} + + # Detected by the jira client based on API version & deployment. 
+            if self.jira_client.deploymentType == 'Cloud':
+                try:
+                    user = self.jira_client._fetch_pages(JiraUser, None, "user/search", 0, 1, {'query': user_email})[0]
+                    reporter = {'accountId': user.accountId}
+                except IndexError:
+                    raise Exception('Could not find the reporting user in our Jira installation.')
+
+            issue_type_id = ISSUE_TYPE_ID
+            if app.config['ISSUE_TRACKER_ISSUE_TYPE_ID']:
+                issue_type_id = app.config['ISSUE_TRACKER_ISSUE_TYPE_ID']
+
+            project_key = kwargs.get('project_key', None)
+            proj_key = 'key' if project_key else 'id'
+            proj_value = project_key if project_key else self.jira_project_id
+
+            reporting_user = self._get_users_from_ids([user_email])
+            owners = self._get_users_from_ids(kwargs.get('owner_ids', []))
+            frequent_users = self._get_users_from_ids(kwargs.get('frequent_user_ids', []))
+
+            reporting_user_str = self._generate_reporting_user_str(reporting_user)
+            owners_description_str = self._generate_owners_description_str(owners)
+            frequent_users_description_str = self._generate_frequent_users_description_str(frequent_users)
+            all_users_description_str = self._generate_all_table_users_description_str(owners_description_str,
+                                                                                       frequent_users_description_str)
+
+            issue = self.jira_client.create_issue(fields=dict(project={
+                proj_key: proj_value
+            }, issuetype={
+                'id': issue_type_id,
+                'name': ISSUE_TYPE_NAME,
+            }, labels=self.issue_labels,
+                summary=title,
+                description=(f'{description} '
+                             f'\n *Reported By:* {reporting_user_str if reporting_user_str else user_email} '
+                             f'\n *Table Key:* {table_uri} [PLEASE DO NOT REMOVE] '
+                             f'\n *Table URL:* {table_url} '
+                             f'{all_users_description_str}'),
+                priority={
+                    'name': Priority.get_jira_severity_from_level(priority_level)
+                }, reporter=reporter))
+
+            self._add_watchers_to_issue(issue_key=issue.key, users=owners + frequent_users)
+
+            return self._get_issue_properties(issue=issue)
+        except JIRAError as e:
+            logging.exception(str(e))
+            raise e
+
+    def _validate_jira_configuration(self) -> None:
+        """
+        Validates that all required Jira configuration settings are present. Collects any missing
+        settings so they can be reported together.
+        :raises IssueConfigurationException: if any required Jira settings are missing.
+ """ + missing_fields = [] + if not self.jira_url: + missing_fields.append('ISSUE_TRACKER_URL') + if not self.jira_user: + missing_fields.append('ISSUE_TRACKER_USER') + if not self.jira_password: + missing_fields.append('ISSUE_TRACKER_PASSWORD') + if not self.jira_project_id: + missing_fields.append('ISSUE_TRACKER_PROJECT_ID') + if not self.jira_max_results: + missing_fields.append('ISSUE_TRACKER_MAX_RESULTS') + + if missing_fields: + raise IssueConfigurationException( + f'The following config settings must be set for Jira: {", ".join(missing_fields)} ') + + @staticmethod + def _get_issue_properties(issue: Issue) -> DataIssue: + """ + Maps the jira issue object to properties we want in the UI + :param issue: Jira issue to map + :return: JiraIssue + """ + return DataIssue(issue_key=issue.key, + title=issue.fields.summary, + url=issue.permalink(), + status=issue.fields.status.name, + priority=Priority.from_jira_severity(issue.fields.priority.name)) + + def _generate_issues_url(self, search_stub: str, table_uri: str, issueCount: int) -> str: + """ + Way to get list of jira tickets + SDK doesn't return a query + :param search_stub: search stub for type of query to build + :param table_uri: table uri from the ui + :param issueCount: number of jira issues associated to the search + :return: url to a list of issues in jira + """ + if issueCount == 0: + return '' + search_query = urllib.parse.quote(search_stub.format(table_key=table_uri)) + return f'{self.jira_url}/issues/?jql={search_query}' + + def _sort_issues(self, issues: List[Issue]) -> List[DataIssue]: + """ + Sorts issues by resolution, first by unresolved and then by resolved. Also maps the issues to + the object used by the front end. Doesn't include closed issues that are older than 30 days. + :param issues: Issues returned from the JIRA API + :return: List of data issues + """ + open = [] + closed = [] + for issue in issues: + data_issue = self._get_issue_properties(issue) + if not issue.fields.resolution: + open.append(data_issue) + else: + closed.append(data_issue) + return open + closed + + @staticmethod + def _get_users_from_ids(user_ids: List[str]) -> List[User]: + """ + Calls get_user metadata API with a user id to retrieve user details. 
+ :param user_ids: List of strings representing user ids + :return: List of User objects + """ + users = [] + for user_id in user_ids: + url = '{0}{1}/{2}'.format(app.config['METADATASERVICE_BASE'], USER_ENDPOINT, user_id) + response = request_metadata(url=url) + if response.status_code == HTTPStatus.OK: + user = load_user(response.json()) + if user: + users.append(user) + return users + + def _generate_reporting_user_str(self, reporting_user: List[User]) -> str: + """ + :param reporting_user: List containing a user representing the reporter of the issue + or an empty list if the reporter's information could not be retrieved + :return: String of reporting user's information to display in the description + """ + if not reporting_user: + return '' + user = reporting_user[0] + if user.is_active and user.profile_url: + return (f'[{user.full_name if user.full_name else user.email}' + f'|{user.profile_url}]') + else: + return user.email if user.email is not None else '' + + def _generate_owners_description_str(self, owners: List[User]) -> str: + """ + Build a list of table owner information to add to the description of the ticket + :param owners: List of users representing owners of the table + :return: String of owners to append in the description + """ + owners_description_str = '\n Table Owners:' if owners else '' + user_details_list = [] + inactive_user_details_list = [] + for user in owners: + if user.is_active and user.profile_url: + user_details_list.append((f'[{user.full_name if user.full_name else user.email}' + f'|{user.profile_url}] ')) + continue + else: + inactive_user_details = f'{user.full_name if user.full_name else user.email}' + + # Append relevant alumni and manager information if the user is a person and inactive + if not user.is_active and user.full_name: + inactive_user_details += ' (Alumni) ' + if user.manager_fullname: + inactive_user_details += f'\u2022 Manager: {user.manager_fullname} ' + inactive_user_details_list.append(inactive_user_details) + return '\n '.join(filter(None, [owners_description_str, + '\n '.join(user_details_list), + '\n '.join(inactive_user_details_list)])) + + def _generate_frequent_users_description_str(self, frequent_users: List[User]) -> str: + """ + Build a list of table frequent user information to add to the description of the ticket; this list will leave + out inactive frequent users + :param frequent_users: List of users representing frequent users of the table + :return: String of frequent users to append in the description + """ + frequent_users_description_str = '\n Frequent Users: ' if frequent_users else '' + user_details_list = [] + for user in frequent_users: + if user.is_active and user.profile_url: + user_details_list.append((f'[{user.full_name if user.full_name else user.email}' + f'|{user.profile_url}]')) + return frequent_users_description_str + ', '.join(user_details_list) if user_details_list else '' + + def _generate_all_table_users_description_str(self, owners_str: str, frequent_users_str: str) -> str: + """ + Takes the generated owners and frequent users information and packages it up into one string for appending + to the ticket description + :param owners_str: Owner information + :param frequent_users_str: Frequent user information + :return: String including all table users (owners and frequent users) information to append to the description + """ + table_users_description_title = '' + if owners_str and frequent_users_str: + table_users_description_title = '\n\n *Owners and Frequent Users (added as Watchers):* ' + elif 
owners_str: + table_users_description_title = '\n\n *Owners (added as Watchers):* ' + elif frequent_users_str: + table_users_description_title = '\n\n *Frequent Users (added as Watchers):* ' + return table_users_description_title + owners_str + frequent_users_str + + def _add_watchers_to_issue(self, issue_key: str, users: List[User]) -> None: + """ + Given an issue key and a list of users, add those users as watchers to the issue if they are active + :param issue_key: key representing an issue + :param users: list of users to add as watchers to the issue + """ + for user in users: + if user.is_active: + try: + # Detected by the jira client based on API version & deployment. + if self.jira_client.deploymentType == 'Cloud': + jira_user = self.jira_client._fetch_pages(JiraUser, None, "user/search", 0, 1, + {'query': user.email})[0] + self.jira_client.add_watcher(issue=issue_key, watcher=jira_user.accountId) + elif user.email is not None: + self.jira_client.add_watcher(issue=issue_key, watcher=user.email.split("@")[0]) + except (JIRAError, IndexError): + logging.warning('Could not add user {user_email} as a watcher on the issue.' + .format(user_email=user.email)) diff --git a/frontend/amundsen_application/sessions.db b/frontend/amundsen_application/sessions.db new file mode 100644 index 0000000000..3f2355a043 Binary files /dev/null and b/frontend/amundsen_application/sessions.db differ diff --git a/frontend/amundsen_application/static/css/_animations.scss b/frontend/amundsen_application/static/css/_animations.scss new file mode 100644 index 0000000000..37efaf64c6 --- /dev/null +++ b/frontend/amundsen_application/static/css/_animations.scss @@ -0,0 +1,62 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +@import 'variables'; + +$loading-duration: 1s; +$loading-curve: cubic-bezier(0.45, 0, 0.15, 1); +$pulse-duration: 1.5s; +$pulse-easing: linear; + +@keyframes pulse { + 0% { + transform: scale(1.1); + } + + 50% { + transform: scale(0.8); + } + + 100% { + transform: scale(1); + } +} + +%is-pulse-animated { + animation: pulse $pulse-duration $pulse-easing infinite; + background-repeat: no-repeat; + background-position: center; + background-size: contain; +} + +.is-pulse-animated { + @extend %is-pulse-animated; +} + +@keyframes shimmer { + 0% { + background-position: 100% 0; + } + + 100% { + background-position: 0 0; + } +} + +%is-shimmer-animated { + animation: $loading-duration shimmer $loading-curve infinite; + background-image: linear-gradient( + to right, + $gray10 0%, + $gray10 33%, + $gray5 50%, + $gray10 67%, + $gray10 100% + ); + background-repeat: no-repeat; + background-size: 300% 100%; +} + +.is-shimmer-animated { + @extend %is-shimmer-animated; +} diff --git a/frontend/amundsen_application/static/css/_avatars.scss b/frontend/amundsen_application/static/css/_avatars.scss new file mode 100644 index 0000000000..e4dcdf8b28 --- /dev/null +++ b/frontend/amundsen_application/static/css/_avatars.scss @@ -0,0 +1,6 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +.sb-avatar > img { + margin: 0; +} diff --git a/frontend/amundsen_application/static/css/_bootstrap-custom.scss b/frontend/amundsen_application/static/css/_bootstrap-custom.scss new file mode 100644 index 0000000000..ebe2080157 --- /dev/null +++ b/frontend/amundsen_application/static/css/_bootstrap-custom.scss @@ -0,0 +1,57 @@ +// Copyright Contributors to the Amundsen project. 
+// SPDX-License-Identifier: Apache-2.0 + +$icon-font-path: '/static/fonts/bootstrap/'; + +// Bootstrap + Custom variables +@import 'variables'; + +// Core mixins +@import '~bootstrap-sass/assets/stylesheets/bootstrap/mixins'; + +// Reset and dependencies +@import '~bootstrap-sass/assets/stylesheets/bootstrap/normalize'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/print'; + +// Core CSS +@import '~bootstrap-sass/assets/stylesheets/bootstrap/scaffolding'; +// Commenting out as we use a specific component for code highlight +// that collides with these styles +// @import '~bootstrap-sass/assets/stylesheets/bootstrap/code'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/grid'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/tables'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/forms'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/buttons'; + +// Components +@import '~bootstrap-sass/assets/stylesheets/bootstrap/component-animations'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/dropdowns'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/button-groups'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/input-groups'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/navs'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/navbar'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/breadcrumbs'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/pagination'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/pager'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/labels'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/badges'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/jumbotron'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/thumbnails'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/alerts'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/progress-bars'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/media'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/list-group'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/panels'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/responsive-embed'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/wells'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/close'; + +// Components w/ JavaScript +@import '~bootstrap-sass/assets/stylesheets/bootstrap/modals'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/tooltip'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/popovers'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/carousel'; + +// Utility classes +@import '~bootstrap-sass/assets/stylesheets/bootstrap/utilities'; +@import '~bootstrap-sass/assets/stylesheets/bootstrap/responsive-utilities'; diff --git a/frontend/amundsen_application/static/css/_buttons-custom.scss b/frontend/amundsen_application/static/css/_buttons-custom.scss new file mode 100644 index 0000000000..7d570cec3b --- /dev/null +++ b/frontend/amundsen_application/static/css/_buttons-custom.scss @@ -0,0 +1,4 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +// This file is intentionally left blank and should be overwritten by in the build process. 
diff --git a/frontend/amundsen_application/static/css/_buttons-default.scss b/frontend/amundsen_application/static/css/_buttons-default.scss new file mode 100644 index 0000000000..3a8ea3f29f --- /dev/null +++ b/frontend/amundsen_application/static/css/_buttons-default.scss @@ -0,0 +1,225 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +@import 'variables'; + +$outline-offset: 2px; +$outline-width: 2px; + +%a11y-outline-styles { + outline: $outline-width solid $brand-color-4; + outline-offset: $outline-offset; + outline-style: auto; + position: relative; + z-index: 10; +} + +.btn { + &.btn-primary, + &.btn-default { + border-width: 2px; + font-weight: $font-weight-body-bold; + padding: 6px $spacer-2; + + img.icon { + border: none; + height: 18px; + margin: 0 $spacer-half 0 0; + -webkit-mask-size: 18px; + mask-size: 18px; + min-width: 18px; + vertical-align: top; + width: 18px; + } + + &.btn-lg { + font-weight: $font-size-large; + height: 48px; + padding: 10px $spacer-2; + + img.icon { + height: $spacer-3; + margin: 0 $spacer-half 0 0; + -webkit-mask-size: $spacer-3; + mask-size: $spacer-3; + min-width: $spacer-3; + width: $spacer-3; + } + } + } + + &.btn-primary { + img.icon { + background-color: $btn-primary-color; + } + + &:not(.disabled):hover, + &:not([disabled]):hover, + &:focus { + background-color: $btn-primary-bg-hover; + border-color: $btn-primary-border-hover; + } + } + + &.btn-default { + img.icon { + background-color: $btn-default-color; + } + + &.muted { + border-color: $divider; + color: $text-secondary; + padding: 0 $spacer-1; + + .icon { + background-color: $text-secondary; + } + } + + &:not(.disabled):hover, + &:not([disabled]):hover, + &:focus { + background-color: $btn-default-bg-hover; + border-color: $btn-default-border-hover; + } + } + + * { + vertical-align: middle; + } + + &.btn-block { + margin-bottom: $spacer-half; + } + + &.btn-flat-icon { + background-color: transparent; + border: none; + box-shadow: none !important; + color: $text-secondary; + padding: 0; + text-align: left; + + &:focus, + &:not(.disabled):hover, + &:not([disabled]):hover { + background-color: transparent; + color: $brand-color-4; + + .icon { + background-color: $brand-color-4; + } + } + } + + &.btn-flat-icon-dark { + background-color: transparent; + border: none; + box-shadow: none !important; + color: $text-secondary; + padding: 0; + text-align: left; + + &:focus, + &:not(.disabled):hover, + &:not([disabled]):hover { + background-color: transparent; + color: $text-primary; + + .icon { + background-color: $text-primary; + } + } + } + + &.btn-nav-bar-icon { + padding: $spacer-1; + display: flex; + align-items: center; + + svg { + fill: $white; + } + + &:hover, + &:focus, + &.is-open { + svg { + fill: $gray20; + } + } + + .is-light & { + svg { + fill: $gray100; + } + + &:hover, + &:focus, + &.is-open { + svg { + fill: $gray70; + } + } + } + } + + &.btn-close { + background-color: $icon-bg; + border: none; + height: 18px; + margin: $spacer-half 0 0; + -webkit-mask-image: url('../images/icons/Close.svg'); + mask-image: url('../images/icons/Close.svg'); + -webkit-mask-position: center; + mask-position: center; + -webkit-mask-size: contain; + -webkit-mask-size: 110%; + mask-size: contain; + mask-size: 110%; + padding: 0; + width: 18px; + + &:focus, + &:not(.disabled):hover, + &:not([disabled]):hover { + background-color: $icon-bg-dark; + } + } + + &.btn-link { + color: $link-color; + text-decoration: none; + padding: $spacer-half 0; + + &:hover, + 
&:focus { + color: $link-hover-color; + } + } + + &.disabled, + &:disabled { + -webkit-box-shadow: none; + -moz-box-shadow: none; + box-shadow: none; + color: $text-secondary; + pointer-events: none; + + &:hover { + color: $text-secondary; + } + } + + &:focus-visible, + &:active:focus, + &:focus { + @extend %a11y-outline-styles; + } +} + +// Outlines for A11y +a:focus-visible, +a:focus { + @extend %a11y-outline-styles; +} diff --git a/frontend/amundsen_application/static/css/_buttons.scss b/frontend/amundsen_application/static/css/_buttons.scss new file mode 100644 index 0000000000..1c7d422d89 --- /dev/null +++ b/frontend/amundsen_application/static/css/_buttons.scss @@ -0,0 +1,5 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +@import 'buttons-default'; +@import 'buttons-custom'; diff --git a/frontend/amundsen_application/static/css/_colors.scss b/frontend/amundsen_application/static/css/_colors.scss new file mode 100644 index 0000000000..ed8e382669 --- /dev/null +++ b/frontend/amundsen_application/static/css/_colors.scss @@ -0,0 +1,267 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +/** + +Avoid using these color definitions directly. +Define semantic variables that reference this color palette instead. +The color palette can be swapped out or modified without +revisiting each individual color usage. + +--------------- +Do this: + $text-primary: $gray100; + + body { + color: $text-primary; + } + +---------------- +Don't do this: + body { + color: $gray100; + } + +*/ + +$white: #fff; +$black: #000; + +/* Red */ +$red0: #fffafb; +$red5: #ffe5e9; +$red10: #ffcfd5; +$red20: #ffa0ac; +$red30: #ff7689; +$red40: #ff516b; +$red50: #ff3354; +$red60: #e6193f; +$red70: #b8072c; +$red80: #8c0020; +$red90: #670019; +$red100: #560015; + +/* Sunset */ +$sunset0: #fffbfa; +$sunset5: #ffe4dd; +$sunset10: #ffccbf; +$sunset20: #ff9e87; +$sunset30: #ff7b5c; +$sunset40: #ff623e; +$sunset50: #ff4e28; +$sunset60: #db3615; +$sunset70: #af230a; +$sunset80: #841604; +$sunset90: #5f0e01; +$sunset100: #4e0b00; + +/* Orange */ +$orange0: #fff6f2; +$orange5: #ffe8dd; +$orange10: #ffd9c7; +$orange20: #ffb38f; +$orange30: #ff915d; +$orange40: #ff7232; +$orange50: #f9560e; +$orange60: #d03d00; +$orange70: #a82e00; +$orange80: #832300; +$orange90: #651a00; +$orange100: #581600; + +/* Amber */ +$amber0: #fffdfa; +$amber5: #fff6e7; +$amber10: #fff0d4; +$amber20: #ffe0a9; +$amber30: #ffd082; +$amber40: #ffc161; +$amber50: #ffb146; +$amber60: #ffa030; +$amber70: #ff8d1f; +$amber80: #fe7e13; +$amber90: #e66909; +$amber100: #cb5803; + +/* Yellow */ +$yellow0: #fffefa; +$yellow5: #fff8d9; +$yellow10: #fff3b8; +$yellow20: #ffe77b; +$yellow30: #ffdd4c; +$yellow40: #ffd32a; +$yellow50: #ffca13; +$yellow60: #ffc002; +$yellow70: #efac00; +$yellow80: #dc9900; +$yellow90: #c78700; +$yellow100: #b07600; + +/* Citron */ +$citron0: #fffff2; +$citron5: #ffffd2; +$citron10: #feffb2; +$citron20: #fbff6f; +$citron30: #f1fb3b; +$citron40: #e2f316; +$citron50: #cce700; +$citron60: #b5d900; +$citron70: #9ac800; +$citron80: #82b400; +$citron90: #6c9c00; +$citron100: #578000; + +/* Lime */ +$lime0: #fdfffa; +$lime5: #edfed0; +$lime10: #d6f3a0; +$lime20: #a4dc48; +$lime30: #75c404; +$lime40: #5eab00; +$lime50: #499300; +$lime60: #347d00; +$lime70: #216800; +$lime80: #155600; +$lime90: #0e4400; +$lime100: #0a3600; + +/* Green */ +$green0: #fafffc; +$green5: #d1ffe2; +$green10: #a8ffc4; +$green20: #4be77a; +$green30: #04cd3d; +$green40: #00b32e; 
+$green50: #009b22; +$green60: #008316; +$green70: #006e0b; +$green80: #005a05; +$green90: #004802; +$green100: #003901; + +/* Mint */ +$mint0: #fafffd; +$mint5: #d1ffee; +$mint10: #a6fbde; +$mint20: #4ae3ae; +$mint30: #04ca83; +$mint40: #00b16f; +$mint50: #00985d; +$mint60: #00824c; +$mint70: #006c3c; +$mint80: #00592f; +$mint90: #004724; +$mint100: #00381c; + +/* Teal */ +$teal0: #fafffe; +$teal5: #d1fff7; +$teal10: #a8fff4; +$teal20: #4ceae4; +$teal30: #04ced2; +$teal40: #00b0b9; +$teal50: #00949f; +$teal60: #007b85; +$teal70: #00626b; +$teal80: #004c53; +$teal90: #003b40; +$teal100: #003338; + +/* Cyan */ +$cyan0: #fafdff; +$cyan5: #e7f6ff; +$cyan10: #d4f0ff; +$cyan20: #a9e1ff; +$cyan30: #82d2ff; +$cyan40: #5dbcf4; +$cyan50: #3a97d3; +$cyan60: #2277b3; +$cyan70: #135b96; +$cyan80: #09457b; +$cyan90: #043563; +$cyan100: #01284e; + +/* Blue */ +$blue0: #fafbff; +$blue5: #e8ecff; +$blue10: #d5dcff; +$blue20: #acbbff; +$blue30: #869dff; +$blue40: #6686ff; +$blue50: #4b73ff; +$blue60: #3668ff; +$blue70: #2156db; +$blue80: #1242af; +$blue90: #093186; +$blue100: #042260; + +/* Indigo */ +$indigo0: #fafaff; +$indigo5: #ebebff; +$indigo10: #dcdcff; +$indigo20: #babaff; +$indigo30: #9c9bff; +$indigo40: #8481ff; +$indigo50: #726bff; +$indigo60: #665aff; +$indigo70: #604cff; +$indigo80: #523be4; +$indigo90: #3e29b1; +$indigo100: #2b1b81; + +/* Purple */ +$purple0: #fdfaff; +$purple5: #f6ebff; +$purple10: #ecdcff; +$purple20: #d7b8ff; +$purple30: #c294ff; +$purple40: #ad71ff; +$purple50: #9b52ff; +$purple60: #8b37ff; +$purple70: #7b20f9; +$purple80: #590dc4; +$purple90: #420499; +$purple100: #390188; + +/* Pink */ +$pink0: #fffafd; +$pink5: #ffe1f2; +$pink10: #ffc7e4; +$pink20: #ff8fcc; +$pink30: #ff5dbb; +$pink40: #ff32b1; +$pink50: #ff0eb0; +$pink60: #de00a7; +$pink70: #bd00a0; +$pink80: #a00093; +$pink90: #860081; +$pink100: #71006f; + +/* Rose */ +$rose0: #fff2f5; +$rose5: #ffe1e9; +$rose10: #ffcfdc; +$rose20: #ffa0ba; +$rose30: #ff769e; +$rose40: #ff5187; +$rose50: #ff3378; +$rose60: #e51966; +$rose70: #b70752; +$rose80: #8b0040; +$rose90: #660031; +$rose100: #55002a; + +/* Gray */ +$gray0: #fcfcff; +$gray5: #f4f4fa; +$gray10: #e7e7ef; +$gray15: #d8d8e4; +$gray20: #cacad9; +$gray30: #acacc0; +$gray40: #9191a8; +$gray50: #787891; +$gray60: #63637b; +$gray70: #515167; +$gray80: #414155; +$gray90: #334; +$gray100: #292936; diff --git a/frontend/amundsen_application/static/css/_dropdowns.scss b/frontend/amundsen_application/static/css/_dropdowns.scss new file mode 100644 index 0000000000..0902d5ee21 --- /dev/null +++ b/frontend/amundsen_application/static/css/_dropdowns.scss @@ -0,0 +1,39 @@ +// Copyright Contributors to the Amundsen project. 
+// SPDX-License-Identifier: Apache-2.0 + +@import 'variables'; +@import 'typography'; + +.dropdown { + .dropdown-toggle { + box-shadow: none; + } + + .dropdown-menu { + border: 1px solid $stroke; + border-radius: $popover-border-radius; + box-shadow: 0 4px 12px -3px rgba(17, 17, 31, 0.12); + overflow: hidden; + padding: 0; + + li { + &:hover { + background-color: $body-bg-tertiary; + } + + a { + padding: $spacer-1; + + &:hover { + background-color: inherit; + } + } + } + } + + .section-title { + @extend %text-title-w3; + + color: $text-tertiary; + } +} diff --git a/frontend/amundsen_application/static/css/_fonts-custom.scss b/frontend/amundsen_application/static/css/_fonts-custom.scss new file mode 100644 index 0000000000..a940fb5ef6 --- /dev/null +++ b/frontend/amundsen_application/static/css/_fonts-custom.scss @@ -0,0 +1,4 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +// This file is intentionally left blank and should be used to add new fonts or overwrite defaults from _fonts-default.scss diff --git a/frontend/amundsen_application/static/css/_fonts-default.scss b/frontend/amundsen_application/static/css/_fonts-default.scss new file mode 100644 index 0000000000..9e39f71e8b --- /dev/null +++ b/frontend/amundsen_application/static/css/_fonts-default.scss @@ -0,0 +1,49 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +@import 'variables'; + +// Space Mono +@font-face { + font-family: 'Space Mono'; + font-style: normal; + font-weight: $font-weight-header-regular; + src: url('../fonts/SpaceMono-Regular.ttf') format('truetype'); +} + +// Roboto +@font-face { + font-family: 'Roboto'; + font-style: normal; + font-weight: $font-weight-header-regular; + src: url('../fonts/Roboto-Medium.ttf') format('truetype'); +} + +@font-face { + font-family: 'Roboto'; + font-style: normal; + font-weight: $font-weight-header-bold; + src: url('../fonts/Roboto-Bold.ttf') format('truetype'); +} + +// Open Sans +@font-face { + font-family: 'Open Sans'; + font-style: normal; + font-weight: $font-weight-body-regular; + src: url('../fonts/OpenSans-Regular.ttf') format('truetype'); +} + +@font-face { + font-family: 'Open Sans'; + font-style: normal; + font-weight: $font-weight-body-semi-bold; + src: url('../fonts/OpenSans-SemiBold.ttf') format('truetype'); +} + +@font-face { + font-family: 'Open Sans'; + font-style: normal; + font-weight: $font-weight-body-bold; + src: url('../fonts/OpenSans-Bold.ttf') format('truetype'); +} diff --git a/frontend/amundsen_application/static/css/_fonts.scss b/frontend/amundsen_application/static/css/_fonts.scss new file mode 100644 index 0000000000..640bd42473 --- /dev/null +++ b/frontend/amundsen_application/static/css/_fonts.scss @@ -0,0 +1,7 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +// Amundsen Default Fonts +@import 'fonts-default'; +// Per-Client Custom Fonts +@import 'fonts-custom'; diff --git a/frontend/amundsen_application/static/css/_icons-custom.scss b/frontend/amundsen_application/static/css/_icons-custom.scss new file mode 100644 index 0000000000..fd924e2e08 --- /dev/null +++ b/frontend/amundsen_application/static/css/_icons-custom.scss @@ -0,0 +1,4 @@ +// Copyright Contributors to the Amundsen project. 
+// SPDX-License-Identifier: Apache-2.0 + +// This file is intentionally left blank and should be used to add new icons or overwrite defaults from _icons-default.scss diff --git a/frontend/amundsen_application/static/css/_icons-default.scss b/frontend/amundsen_application/static/css/_icons-default.scss new file mode 100644 index 0000000000..1a17527cad --- /dev/null +++ b/frontend/amundsen_application/static/css/_icons-default.scss @@ -0,0 +1,228 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +@import 'variables'; + +$icon-size: 24px; +$icon-small-size: 16px; + +// Icons +// Lookout! When you update one of these, please update the enums on +// ../js/interfaces/Enums.ts +// Map of Database names and icon paths +$data-stores: ( + database: '../images/icons/Database.svg', + hive: '../images/icons/logo-hive.svg', + bigquery: '../images/icons/logo-bigquery.svg', + delta: '../images/icons/logo-delta.png', + dremio: '../images/icons/logo-dremio.svg', + druid: '../images/icons/logo-druid.svg', + oracle: '../images/icons/logo-oracle.svg', + presto: '../images/icons/logo-presto.svg', + trino: '../images/icons/logo-trino.svg', + postgres: '../images/icons/logo-postgres.svg', + redshift: '../images/icons/logo-redshift.svg', + snowflake: '../images/icons/logo-snowflake.svg', + elasticsearch: '../images/icons/logo-elasticsearch.svg', + teradata: '../images/icons/logo-teradata.svg', +); + +// Map of Dashboard names and icon paths +$dashboards: ( + dashboard: '../images/icons/dashboard.svg', + mode: '../images/icons/logo-mode.svg', + redash: '../images/icons/logo-redash.svg', + tableau: '../images/icons/logo-tableau.svg', + superset: '../images/icons/logo-superset.svg', + databricks_sql: '../images/icons/logo-databricks-sql.svg', + powerbi: '../images/icons/logo-powerbi.svg', +); + +// Map of User names and icon paths +$users: ( + users: '../images/icons/users.svg', +); + +$check: ( + check: '../images/icons/check.svg', +); + +// Given a Map of key/value pairs, generates a new class +@mixin iconBackgrounds($map) { + @each $name, $url in $map { + &.icon-#{$name} { + background: transparent url($url) center center / contain no-repeat; + } + } +} + +span.icon { + // Generate Icons + @include iconBackgrounds($data-stores); + @include iconBackgrounds($dashboards); + @include iconBackgrounds($users); + @include iconBackgrounds($check); + + background-color: $icon-bg; + border: none; + display: inline-block; + height: $icon-size; + margin: auto 16px auto 0; + min-width: $icon-size; + vertical-align: middle; + width: $icon-size; +} + +img.icon { + /* DEPRECATED: follow behavior above to generate + icons */ + background-color: $icon-bg; + border: none; + height: $icon-size; + margin: -3px 4px -3px 0; + -webkit-mask-repeat: no-repeat; + mask-repeat: no-repeat; + -webkit-mask-size: contain; + mask-size: contain; + min-width: $icon-size; + width: $icon-size; + + &.icon-small { + height: $icon-small-size; + -webkit-mask-size: $icon-small-size $icon-small-size; + mask-size: $icon-small-size $icon-small-size; + min-width: $icon-small-size; + width: $icon-small-size; + } + + &.icon-color { + background-color: $icon-bg-brand; + } + + &.icon-dark { + background-color: $icon-bg-dark; + } + + &.icon-alert { + -webkit-mask-image: url('../images/icons/Alert-Triangle.svg'); + mask-image: url('../images/icons/Alert-Triangle.svg'); + } + + &.icon-bookmark { + -webkit-mask-image: url('../images/icons/Favorite.svg'); + mask-image: url('../images/icons/Favorite.svg'); + } + + 
&.icon-bookmark-filled { + -webkit-mask-image: url('../images/icons/Favorite-Filled.svg'); + mask-image: url('../images/icons/Favorite-Filled.svg'); + } + + &.icon-delete { + -webkit-mask-image: url('../images/icons/Trash.svg'); + mask-image: url('../images/icons/Trash.svg'); + } + + &.icon-red-triangle-warning { + -webkit-mask-image: url('../images/icons/DataQualityWarning.svg'); + mask-image: url('../images/icons/DataQualityWarning.svg'); + } + + &.icon-down { + -webkit-mask-image: url('../images/icons/Down.svg'); + mask-image: url('../images/icons/Down.svg'); + } + + &.icon-edit { + -webkit-mask-image: url('../images/icons/Edit.svg'); + mask-image: url('../images/icons/Edit.svg'); + } + + &.icon-help { + -webkit-mask-image: url('../images/icons/Help-Circle.svg'); + mask-image: url('../images/icons/Help-Circle.svg'); + } + + &.icon-github { + -webkit-mask-image: url('../images/icons/github.svg'); + mask-image: url('../images/icons/github.svg'); + } + + &.icon-left { + -webkit-mask-image: url('../images/icons/Left.svg'); + mask-image: url('../images/icons/Left.svg'); + } + + &.icon-loading { + -webkit-mask-image: url('../images/icons/Loader.svg'); + mask-image: url('../images/icons/Loader.svg'); + } + + &.icon-mail { + -webkit-mask-image: url('../images/icons/mail.svg'); + mask-image: url('../images/icons/mail.svg'); + } + + &.icon-plus { + -webkit-mask-image: url('../images/icons/plus.svg'); + mask-image: url('../images/icons/plus.svg'); + } + + &.icon-plus-circle { + -webkit-mask-image: url('../images/icons/Plus-Circle.svg'); + mask-image: url('../images/icons/Plus-Circle.svg'); + } + + &.icon-preview { + -webkit-mask-image: url('../images/icons/Preview.svg'); + mask-image: url('../images/icons/Preview.svg'); + } + + &.icon-refresh { + -webkit-mask-image: url('../images/icons/Refresh-cw.svg'); + mask-image: url('../images/icons/Refresh-cw.svg'); + } + + &.icon-right { + -webkit-mask-image: url('../images/icons/Right.svg'); + mask-image: url('../images/icons/Right.svg'); + } + + &.icon-search { + -webkit-mask-image: url('../images/icons/Search.svg'); + mask-image: url('../images/icons/Search.svg'); + } + + &.icon-send { + -webkit-mask-image: url('../images/icons/Send.svg'); + mask-image: url('../images/icons/Send.svg'); + } + + &.icon-slack { + -webkit-mask-image: url('../images/icons/slack.svg'); + mask-image: url('../images/icons/slack.svg'); + } + + &.icon-up { + -webkit-mask-image: url('../images/icons/Up.svg'); + mask-image: url('../images/icons/Up.svg'); + } + + &.icon-user { + -webkit-mask-image: url('../images/icons/users.svg'); + mask-image: url('../images/icons/users.svg'); + } + + &.icon-more { + -webkit-mask-image: url('../images/icons/More.svg'); + mask-image: url('../images/icons/More.svg'); + } +} + +.disabled, +:disabled { + > img.icon, + > img.icon.icon-color { + background-color: $icon-bg-disabled; + } +} diff --git a/frontend/amundsen_application/static/css/_icons.scss b/frontend/amundsen_application/static/css/_icons.scss new file mode 100644 index 0000000000..b412760056 --- /dev/null +++ b/frontend/amundsen_application/static/css/_icons.scss @@ -0,0 +1,7 @@ +// Copyright Contributors to the Amundsen project. 
+// SPDX-License-Identifier: Apache-2.0 + +// Amundsen Default Typography +@import 'icons-default'; +// Per-Client Custom Typography +@import 'icons-custom'; diff --git a/frontend/amundsen_application/static/css/_inputs.scss b/frontend/amundsen_application/static/css/_inputs.scss new file mode 100644 index 0000000000..a32e07a6fe --- /dev/null +++ b/frontend/amundsen_application/static/css/_inputs.scss @@ -0,0 +1,42 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +@import 'variables'; + +input { + &::-webkit-input-placeholder, + &::-moz-placeholder, + &:-ms-input-placeholder, + &:-moz-placeholder, + &::placeholder { + color: $text-placeholder !important; + } + + &:-webkit-autofill, + &:-webkit-autofill:hover, + &:-webkit-autofill:focus, + &:-webkit-autofill:active { + -webkit-box-shadow: 0 0 0 1000px $white inset !important; + box-shadow: 0 0 0 1000px $white inset !important; + } + + &[type='radio'] { + margin: 5px; + } + + &[type='text'] { + color: $text-secondary !important; + } + + &:not([disabled]) { + cursor: pointer; + } +} + +textarea { + border: 1px solid $stroke; + border-radius: 5px; + color: $text-secondary !important; + padding: 10px; + width: 100%; +} diff --git a/frontend/amundsen_application/static/css/_labels.scss b/frontend/amundsen_application/static/css/_labels.scss new file mode 100644 index 0000000000..14cea95538 --- /dev/null +++ b/frontend/amundsen_application/static/css/_labels.scss @@ -0,0 +1,29 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +@import 'variables'; + +.label-negative { + background-color: $badge-negative-color; + color: $badge-text-color; +} + +.label-neutral { + background-color: $badge-neutral-color; + color: $badge-text-color; +} + +.label-primary { + background-color: $badge-primary-color; + color: $badge-text-color; +} + +.label-positive { + background-color: $badge-positive-color; + color: $badge-text-color; +} + +.label-warning { + background-color: $badge-warning-color; + color: $badge-text-color; +} diff --git a/frontend/amundsen_application/static/css/_layouts.scss b/frontend/amundsen_application/static/css/_layouts.scss new file mode 100644 index 0000000000..5e700df2af --- /dev/null +++ b/frontend/amundsen_application/static/css/_layouts.scss @@ -0,0 +1,290 @@ +// Copyright Contributors to the Amundsen project. 
+// SPDX-License-Identifier: Apache-2.0 + +@import 'variables'; +@import 'typography'; + +$resource-header-height: 84px; +$aside-separation-space: $spacer-3; +$screen-lg-container: 1440px; +$header-link-height: 32px; +$icon-header-size: 32px; +$inner-column-size: 175px; +$close-btn-size: 24px; + +.resource-detail-layout { + height: calc(100vh - #{$nav-bar-height} - #{$footer-height}); + + .resource-header { + border-bottom: 2px solid $divider; + display: flex; + height: $resource-header-height; + padding: $spacer-2 $spacer-3; + + .icon-header { + height: $icon-header-size; + margin: 10px; + width: $icon-header-size; + } + + .header-section { + flex-shrink: 0; + + &.header-title { + flex-grow: 1; + + .header-title-text { + display: inline-block; + max-width: calc(100% - 100px); + } + } + + .amundsen-breadcrumb { + // Vertically align the breadcrumb + // (84px header height - 18px breadcrumb height) / 2 for top & bottom - 16px resource-header padding = 17px + + padding-top: 17px; + } + + .header-bullets { + display: inline; + margin: 0 $spacer-1 0 0; + padding: 0; + + li { + display: inline; + + &::after { + content: '\00A0\2022\00A0'; + } + + &:last-child::after { + content: ''; + } + } + } + + &.header-links { + flex-shrink: 0; + + > * { + margin-right: $spacer-2; + } + + .header-link { + display: inline-block; + margin: 0 $spacer-2 0 0; + line-height: $header-link-height; + + .avatar-label { + font-weight: $font-weight-body-bold; + } + } + } + + &.header-external-links { + display: flex; + align-items: center; + } + + &.header-buttons { + flex-shrink: 0; + + > * { + margin-right: $spacer-1; + + &:last-child { + margin-right: 0; + } + } + } + } + } + + // Outer column layout + .single-column-layout { + display: flex; + height: calc(100% - #{$resource-header-height}); + + > .left-panel { + border-right: $spacer-half solid $divider; + flex-basis: $left-panel-width; + flex-shrink: 0; + min-height: min-content; + overflow-y: auto; + padding: 0 $spacer-3 $aside-separation-space; + + > .banner { + border: 1px solid $stroke; + height: 40px; + margin: $spacer-3 $spacer-3 0; + padding: $spacer-1; + } + + .section-title { + @extend %text-title-w3; + + color: $text-tertiary; + margin-bottom: $spacer-1; + } + + .editable-section, + .metadata-section { + margin-top: $aside-separation-space; + position: relative; + } + + .editable-text { + font-size: $w2-font-size; + } + + .avatar-label-component { + .avatar-label { + color: $text-primary; + } + } + + .markdown-wrapper { + font-size: $w2-font-size; + + // Restrict max size of header elements + h1, + h2, + h3 { + font-size: 20px; + font-weight: $font-weight-header-bold; + line-height: 28px; + } + } + } + + > .right-panel { + border-left: $spacer-half solid $divider; + flex-basis: $right-panel-width; + flex-shrink: 0; + min-height: min-content; + overflow-y: auto; + padding: 0 $spacer-3 $aside-separation-space; + + .panel-header { + display: flex; + justify-content: space-between; + } + + .panel-title { + @extend %text-title-w1; + + margin-top: $aside-separation-space; + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + } + + .btn-close { + flex-basis: $close-btn-size; + flex-shrink: 0; + margin-top: $aside-separation-space; + } + + .buttons-row { + display: flex; + gap: $spacer-2; + margin-top: $aside-separation-space; + } + + .btn.btn-default { + line-height: $spacer-2; + } + + .section-title { + @extend %text-title-w3; + + color: $text-tertiary; + margin-bottom: $spacer-1; + } + + .editable-section, + .metadata-section { + 
margin-top: $aside-separation-space; + position: relative; + } + + .editable-text { + font-size: $w2-font-size; + } + } + + > .main-content-panel { + flex-basis: $main-content-panel-width; + flex-grow: 1; + flex-shrink: 0; + overflow-y: scroll; + width: 0; // Required for text truncation + } + + @media (max-width: 1200px) { + > .left-panel { + flex-basis: $left-panel-smaller-width; + } + + > .right-panel { + flex-basis: $right-panel-smaller-width; + } + } + } + + // Inner column layout + .two-column-layout { + display: flex; + + > .left-column { + flex-basis: $inner-column-size; + flex-direction: column; + margin-right: 12px; + } + + > .right-column { + flex-basis: $inner-column-size; + margin-left: 12px; + } + } + + .left-panel, + .right-panel, + .main-content-panel { + display: flex; + flex-direction: column; + } +} + +// Main Layout +#main { + min-width: $body-min-width; +} + +@media (min-width: $screen-lg-max) { + #main > .container { + width: $screen-lg-container; + } +} + +#main > .container { + margin: 96px auto 48px; +} + +@media (max-width: $screen-md-max) { + #main > .container { + margin: 64px auto 48px; + } +} + +@media (max-width: $screen-sm-max) { + #main > .container { + margin: 32px auto 48px; + } +} + +.my-auto { + margin-bottom: auto; + margin-top: auto; +} diff --git a/frontend/amundsen_application/static/css/_list-group.scss b/frontend/amundsen_application/static/css/_list-group.scss new file mode 100644 index 0000000000..dc440a0c79 --- /dev/null +++ b/frontend/amundsen_application/static/css/_list-group.scss @@ -0,0 +1,20 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +@import 'variables'; + +.list-group { + margin: 0; + + .list-group-item { + border-left: none; + border-right: none; + padding: 0; + + &.clickable:hover { + box-shadow: $hover-box-shadow; + cursor: pointer; + z-index: 1; + } + } +} diff --git a/frontend/amundsen_application/static/css/_pagination.scss b/frontend/amundsen_application/static/css/_pagination.scss new file mode 100644 index 0000000000..4c3334b7fd --- /dev/null +++ b/frontend/amundsen_application/static/css/_pagination.scss @@ -0,0 +1,39 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +@import 'variables'; + +.pagination { + display: flex; + justify-content: center; + + li { + > a, + > span { + border: 1px solid $stroke; + color: $brand-color-4; + + &:focus, + &:hover { + background-color: $body-bg-tertiary; + color: $link-hover-color; + z-index: 0; + } + } + + &.active { + > a, + > span { + &, + &:active, + &:hover, + &:focus { + background-color: $brand-color-4; + border-color: $brand-color-4; + color: $white; + z-index: 0; + } + } + } + } +} diff --git a/frontend/amundsen_application/static/css/_popovers.scss b/frontend/amundsen_application/static/css/_popovers.scss new file mode 100644 index 0000000000..f44dcd71f9 --- /dev/null +++ b/frontend/amundsen_application/static/css/_popovers.scss @@ -0,0 +1,97 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +// TODO - Override Bootstrap variables and delete this. 
+@import 'variables'; +@import 'typography'; + +.popover { + background-color: $body-bg-dark; + border: 1px solid $body-bg-dark; + color: $text-inverse; + font-size: 12px; + padding: 5px; +} + +.popover-title { + border-bottom: 1px solid $stroke; + color: $text-inverse; + font-size: 14px; + padding: 5px; +} + +.popover-content { + padding: 2px 5px; + word-break: break-word; +} + +.popover.right .arrow::after { + border-right-color: $body-bg-dark; +} + +.popover.bottom .arrow::after { + border-bottom-color: $body-bg-dark; +} + +.popover.top .arrow::after { + border-top-color: $body-bg-dark; +} + +.popover.left .arrow::after { + border-left-color: $body-bg-dark; +} + +.tooltip-inner { + background-color: $body-bg-dark; + border-radius: 3px; + padding: 0; +} + +.tooltip-inner button { + background-color: $body-bg-dark; + border: none; + border-radius: 3px; + color: $body-bg; + font-size: 14px; + font-weight: $font-weight-body-bold; + height: 36px; + outline: none; + width: 96px; +} + +.tooltip-inner button:hover { + color: $text-secondary; +} + +.error-tooltip { + display: flex; + padding: 5px; +} + +.error-tooltip button { + height: 24px; + margin: auto; + width: 24px; +} + +// Modals +.modal-header { + text-align: left; + border-bottom: none; + padding: $spacer-3 $spacer-3 $spacer-2; +} + +.modal-title { + @extend %text-title-w1; +} + +.modal-body { + text-align: left; + padding: $spacer-2 $spacer-3; + overflow: auto; +} + +.modal-footer { + border-top: none; + padding: $spacer-2 $spacer-3 $spacer-3; +} diff --git a/frontend/amundsen_application/static/css/_typography-custom.scss b/frontend/amundsen_application/static/css/_typography-custom.scss new file mode 100644 index 0000000000..7d570cec3b --- /dev/null +++ b/frontend/amundsen_application/static/css/_typography-custom.scss @@ -0,0 +1,4 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +// This file is intentionally left blank and should be overwritten by in the build process. diff --git a/frontend/amundsen_application/static/css/_typography-default.scss b/frontend/amundsen_application/static/css/_typography-default.scss new file mode 100644 index 0000000000..23780e6099 --- /dev/null +++ b/frontend/amundsen_application/static/css/_typography-default.scss @@ -0,0 +1,396 @@ +// Copyright Contributors to the Amundsen project. 
+// SPDX-License-Identifier: Apache-2.0 + +@import 'variables'; + +// New Typography styles based on LPL +// Use these styles going forward + +// Placeholder selectors +// Ref: http://thesassway.com/intermediate/understanding-placeholder-selectors +// +// Example of use: +// .header-title-text { +// @extend %text-headline-w2; +// } + +%text-headline-w1 { + font-family: $text-heading-font-family; + font-size: $w1-headline-font-size; + line-height: $w1-headline-line-height; + font-weight: $title-font-weight; +} + +%text-headline-w2 { + font-family: $text-heading-font-family; + font-size: $w2-headline-font-size; + line-height: $w2-headline-line-height; + font-weight: $title-font-weight; +} + +%text-headline-w3 { + font-family: $text-heading-font-family; + font-size: $w3-headline-font-size; + line-height: $w3-headline-line-height; + font-weight: $title-font-weight; +} + +%text-title-w1 { + font-family: $text-body-font-family; + font-size: $w1-font-size; + line-height: $w1-line-height; + font-weight: $title-font-weight; +} + +%text-title-w2 { + font-family: $text-body-font-family; + font-size: $w2-font-size; + line-height: $w2-line-height; + font-weight: $title-font-weight; +} + +%text-title-w3 { + font-family: $text-body-font-family; + font-size: $w3-font-size; + line-height: $w3-line-height; + font-weight: $title-font-weight; +} + +%text-subtitle-w1 { + font-family: $text-body-font-family; + font-size: $w1-font-size; + line-height: $w1-line-height; + font-weight: $subtitle-font-weight; +} + +%text-subtitle-w2 { + font-family: $text-body-font-family; + font-size: $w2-font-size; + line-height: $w2-line-height; + font-weight: $subtitle-font-weight; +} + +%text-subtitle-w3 { + font-family: $text-body-font-family; + font-size: $w3-font-size; + line-height: $w3-line-height; + font-weight: $subtitle-font-weight; +} + +%text-body-w1 { + font-family: $text-body-font-family; + font-size: $w1-font-size; + line-height: $w1-line-height; + font-weight: $body-font-weight; +} + +%text-body-w2 { + font-family: $text-body-font-family; + font-size: $w2-font-size; + line-height: $w2-line-height; + font-weight: $body-font-weight; +} + +%text-body-w3 { + font-family: $text-body-font-family; + font-size: $w3-font-size; + line-height: $w3-line-height; + font-weight: $body-font-weight; +} + +%text-monospace-w3 { + font-family: $font-family-monospace-code; + font-size: $w3-font-size; + line-height: $w3-line-height; + font-weight: $body-font-weight; +} + +%text-caption-w1 { + font-size: $w1-caption-font-size; + line-height: $w1-caption-line-height; + font-weight: $caption-font-weight; + text-transform: uppercase; +} + +%text-caption-w2 { + font-size: $w2-caption-font-size; + line-height: $w2-caption-line-height; + font-weight: $caption-font-weight; + text-transform: uppercase; +} + +// Typography classes +// Headlines +.text-headline-w1 { + @extend %text-headline-w1; +} + +.text-headline-w2 { + @extend %text-headline-w2; +} + +.text-headline-w3 { + @extend %text-headline-w3; +} + +// Titles +.text-title-w1 { + @extend %text-title-w1; +} + +.text-title-w2 { + @extend %text-title-w2; +} + +.text-title-w3 { + @extend %text-title-w3; +} + +// Subtitles +.text-subtitle-w1 { + @extend %text-subtitle-w1; +} + +.text-subtitle-w2 { + @extend %text-subtitle-w2; +} + +.text-subtitle-w3 { + @extend %text-subtitle-w3; +} + +// Body +.text-body-w1 { + @extend %text-body-w1; +} + +.text-body-w2 { + @extend %text-body-w2; +} + +.text-body-w3 { + @extend %text-body-w3; +} + +// Captions +.text-caption-w1 { + @extend %text-caption-w1; 
+} + +.text-caption-w2 { + @extend %text-caption-w2; +} + +// Monospace +.text-monospace-w3 { + @extend %text-monospace-w3; +} + +// Typography Helpers +.text-center { + text-align: center; +} + +.text-left { + text-align: left; +} + +.text-right { + text-align: right; +} + +.truncated { + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; +} + +// Text for Screen Readers only +// Reference: Bootstrap 4 codebase +.sr-only { + position: absolute !important; + width: 1px !important; + height: 1px !important; + padding: 0 !important; + margin: -1px !important; + overflow: hidden !important; + clip: rect(0, 0, 0, 0) !important; + white-space: nowrap !important; + border: 0 !important; +} + +// From https://gist.github.com/igorescobar/d74a76629bab47d601d71c3a6e010ff2 +@mixin truncate($font-size, $line-height, $lines-to-show) { + display: block; // Fallback for non-webkit + display: -webkit-box; + font-size: $font-size; + line-height: $line-height; + -webkit-line-clamp: $lines-to-show; + -webkit-box-orient: vertical; + overflow: hidden; + text-overflow: ellipsis; +} + +// Old typography styles +// DEPRECATED - Don't use! +h1, +h2, +h3, +h4, +h5, +h6 { + margin: 0; +} + +h1, +h2, +h3 { + color: $text-primary; + font-family: $font-family-header; +} + +h1 { + font-size: 36px; + font-weight: $font-weight-header-regular; + line-height: 34px; +} + +h2 { + font-size: 26px; + font-weight: $font-weight-header-bold; + line-height: 34px; +} + +h3 { + font-size: 20px; + font-weight: $font-weight-header-bold; + line-height: 34px; +} + +body { + color: $text-primary; + font-family: $font-family-body; + font-size: 14px; + font-weight: $font-weight-body-regular; +} + +.title-2, +.title-3, +.subtitle-1, +.subtitle-2, +.body-1, +.body-2, +.body-3, +.body-secondary-3, +.body-placeholder, +.body-link, +.caption { + font-family: $font-family-body; +} + +.title-2 { + font-size: 16px; + font-weight: $font-weight-body-bold; + line-height: 1.42857; +} + +.title-3 { + color: $text-secondary; + font-size: 14px; + font-weight: $font-weight-body-bold; + line-height: 1.42857; +} + +.subtitle-1 { + font-size: 20px; + font-weight: $font-weight-body-semi-bold; + line-height: 28px; +} + +.subtitle-2 { + font-size: 16px; + font-weight: $font-weight-body-semi-bold; +} + +.body-1 { + font-size: 20px; + font-weight: $font-weight-body-regular; + line-height: 20px; +} + +.body-2 { + font-size: 16px; + font-weight: $font-weight-body-regular; + line-height: 20px; +} + +.body-3 { + font-size: 14px; + font-weight: $font-weight-body-regular; +} + +.body-secondary-3 { + color: $text-secondary; + font-size: 14px; + font-weight: $font-weight-body-regular; +} + +.body-placeholder { + color: $text-placeholder; + font-size: 14px; + font-weight: $font-weight-body-regular; +} + +.body-link { + color: $brand-color-4; + font-size: $font-size-large; + + &:link, + &:visited, + &:hover, + &:active { + text-decoration: none; + } +} + +.caption { + color: $text-secondary; + font-size: 13px; + font-weight: $font-weight-body-bold; +} + +.column-name { + @extend %text-title-w3; + + color: $column-name-color; +} + +.column-type-label { + @extend %text-title-w3; + + color: $text-primary; +} + +.resource-type { + @extend %text-subtitle-w3; + + color: $text-placeholder; +} + +.helper-text { + color: $text-secondary; + font-family: $font-family-body; + font-size: 12px; +} + +.text-placeholder { + color: $text-placeholder; +} + +.text-secondary { + color: $text-secondary; +} + +.text-primary { + color: $text-primary; +} diff --git 
a/frontend/amundsen_application/static/css/_typography.scss b/frontend/amundsen_application/static/css/_typography.scss new file mode 100644 index 0000000000..8f8b924b37 --- /dev/null +++ b/frontend/amundsen_application/static/css/_typography.scss @@ -0,0 +1,7 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +// Amundsen Default Typography +@import 'typography-default'; +// Per-Client Custom Typography +@import 'typography-custom'; diff --git a/frontend/amundsen_application/static/css/_variables-custom.scss b/frontend/amundsen_application/static/css/_variables-custom.scss new file mode 100644 index 0000000000..5fe2eef4bf --- /dev/null +++ b/frontend/amundsen_application/static/css/_variables-custom.scss @@ -0,0 +1,4 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +// This file is intentionally left blank and should be used to add new custom variables or overwrite defaults from _variables-default.scss diff --git a/frontend/amundsen_application/static/css/_variables-default.scss b/frontend/amundsen_application/static/css/_variables-default.scss new file mode 100644 index 0000000000..1a48c15012 --- /dev/null +++ b/frontend/amundsen_application/static/css/_variables-default.scss @@ -0,0 +1,198 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +@import 'colors'; + +// TODO - consider using more descriptive names, or replacing with more specific variables. +// Colors +$brand-color-1: $indigo10 !default; +$brand-color-2: $indigo20 !default; +$brand-color-3: $indigo40 !default; +$brand-color-4: $indigo60 !default; +$brand-color-5: $indigo80 !default; + +$brand-primary: $brand-color-4 !default; + +/* Scaffolding */ +$body-bg: $white !default; +$body-bg-secondary: $gray0 !default; +$body-bg-tertiary: $gray5 !default; +$body-bg-dark: $gray100 !default; +$divider: $gray15 !default; +$stroke: $gray20 !default; +$stroke-light: $gray10 !default; +$stroke-focus: $gray60 !default; +$stroke-underline: $gray40 !default; + +// Typography +$text-primary: $gray100 !default; +$text-secondary: $gray60 !default; +$text-tertiary: $gray40 !default; +$text-placeholder: $gray40 !default; +$text-inverse: $white !default; + +$column-name-color: $indigo70; +$link-color: $brand-color-4; +$link-hover-color: $brand-color-5; + +$font-family-body: 'Open Sans', sans-serif !default; +$font-weight-body-regular: 400 !default; +$font-weight-body-semi-bold: 600 !default; +$font-weight-body-bold: 700 !default; + +$font-family-header: 'Roboto', sans-serif !default; +$font-weight-header-regular: 500 !default; +$font-weight-header-bold: 700 !default; + +$font-family-monospace-code: 'Space Mono', menlo, monospace !default; +$font-family-serif: georgia, 'Times New Roman', times, serif !default; + +$font-size-small: 12px !default; +$font-size-base: 14px !default; +$font-size-large: 16px !default; +$line-height-small: 1.5 !default; +$line-height-large: 1.5 !default; + +// Badges +$badge-text-color: $text-primary; +$badge-negative-color: $sunset20; +$badge-neutral-color: $gray20; +$badge-primary-color: $cyan10; +$badge-positive-color: $mint20; +$badge-warning-color: $amber30; + +$badge-overlay: $gray100; +$badge-opacity-light: 0.14; +$badge-opacity-dark: 0.16; +$badge-pressed-light: 0.21; +$badge-pressed-dark: 0.22; + +$badge-height: 20px; + +// Buttons +$btn-border-radius-base: 4px; + +$btn-primary-bg: $brand-color-4 !default; +$btn-primary-bg-hover: $brand-color-5 !default; +$btn-primary-border: 
transparent !default; +$btn-primary-border-hover: transparent !default; +$btn-primary-color: $white !default; +$btn-primary-color-hover: $white !default; + +$btn-default-bg: $white !default; +$btn-default-bg-hover: $gray5 !default; +$btn-default-border: $gray20 !default; +$btn-default-border-hover: $gray30 !default; +$btn-default-color: $gray100 !default; +$btn-default-color-hover: $gray90 !default; + +// Icons +$icon-bg: $gray20 !default; +$icon-bg-brand: $brand-color-3 !default; +$icon-bg-dark: $gray60 !default; +$icon-bg-disabled: $gray20 !default; +$red-triangle-warning: $sunset60; + +// Header, Body, & Footer +$nav-bar-color: $indigo100; +$light-nav-bar-color: $white; +$nav-bar-height: 48px; +$body-min-width: 1048px; +$footer-height: 60px; +$navbar-item-line-height: 45px; + +// Extra breakpoints +$screen-lg-max: 1490px; + +// SearchPanel +$search-panel-width: 270px; +$search-panel-border-width: 4px; + +// Layout Panels +$main-content-panel-width: 500px; +$left-panel-width: 425px; +$right-panel-width: 425px; +$left-panel-smaller-width: 415px; +$right-panel-smaller-width: 415px; + +// List Group +$list-group-border: $stroke !default; +$list-group-border-radius: 0 !default; + +// Labels +$label-primary-bg: $brand-color-3 !default; + +//Priority +$priority-text-blocker: $white; +$priority-bg-color: $rose80; + +// Tabs +$tab-content-margin-top: 56px; + +// Tags +$tag-bg: $gray5; +$tag-bg-hover: $gray10; +$tag-border-radius: 4px; + +// TODO Temp Colors +$resource-title-color: $indigo60; + +// Spacing +$spacer-size: 8px; +$spacer-half: $spacer-size/2; +$spacer-1: $spacer-size; +$spacer-2: $spacer-size * 2; +$spacer-3: $spacer-size * 3; +$spacer-4: $spacer-size * 4; +$spacer-5: $spacer-size * 5; +$spacer-6: $spacer-size * 6; + +// Elevations (from LPL) +$elevation-level1: 0 0 1px 0 rgba(0, 0, 0, 0.12), + 0 1px 1px 1px rgba(0, 0, 0, 0.08); +$elevation-level2: 0 0 1px 0 rgba(0, 0, 0, 0.12), + 0 2px 3px 0 rgba(0, 0, 0, 0.16); +$elevation-level3: 0 0 1px 0 rgba(0, 0, 0, 0.12), + 0 2px 4px 0 rgba(0, 0, 0, 0.16); +$elevation-level4: 0 0 1px 0 rgba(0, 0, 0, 0.12), + 0 3px 6px 0 rgba(0, 0, 0, 0.16); + +// New Typography variables based on LPL +$text-heading-font-family: $font-family-header; +$text-body-font-family: $font-family-body; + +$w1-font-size: 20px; +$w1-line-height: 24px; + +$w2-font-size: 16px; +$w2-line-height: 20px; + +$w3-font-size: 14px; +$w3-line-height: 18px; + +$w1-headline-font-size: 36px; +$w1-headline-line-height: 44px; + +$w2-headline-font-size: 26px; +$w2-headline-line-height: 32px; + +$w3-headline-font-size: 22px; +$w3-headline-line-height: 28px; + +$w1-caption-font-size: $w2-font-size; +$w1-caption-line-height: $w2-line-height; + +$w2-caption-font-size: 12px; +$w2-caption-line-height: 16px; + +$code-font-size: 12px; + +$title-font-weight: $font-weight-body-bold; +$subtitle-font-weight: $font-weight-body-semi-bold; +$body-font-weight: $font-weight-body-regular; +$caption-font-weight: $font-weight-body-bold; + +$hover-box-shadow: 0 0 1px 0 rgba(0, 0, 0, 0.12), + 0 2px 3px 0 rgba(0, 0, 0, 0.16); + +$popover-border-radius: 12px; diff --git a/frontend/amundsen_application/static/css/_variables.scss b/frontend/amundsen_application/static/css/_variables.scss new file mode 100644 index 0000000000..c0a4b7eafc --- /dev/null +++ b/frontend/amundsen_application/static/css/_variables.scss @@ -0,0 +1,9 @@ +// Copyright Contributors to the Amundsen project. 
+// SPDX-License-Identifier: Apache-2.0 + +// Amundsen Default Values +@import 'variables-default'; +// Per-Client Custom Values +@import 'variables-custom'; +// Bootstrap Default Values +@import '~bootstrap-sass/assets/stylesheets/bootstrap/variables'; diff --git a/frontend/amundsen_application/static/css/styles.scss b/frontend/amundsen_application/static/css/styles.scss new file mode 100644 index 0000000000..b91f1d68c2 --- /dev/null +++ b/frontend/amundsen_application/static/css/styles.scss @@ -0,0 +1,28 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +@import 'bootstrap-custom'; +@import 'animations'; +@import 'avatars'; +@import 'buttons'; +@import 'dropdowns'; +@import 'fonts'; +@import 'icons'; +@import 'inputs'; +@import 'labels'; +@import 'layouts'; +@import 'list-group'; +@import 'pagination'; +@import 'popovers'; +@import 'typography'; + +// Misc +td { + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +form { + margin-bottom: 0; +} diff --git a/frontend/amundsen_application/static/fonts/OpenSans-Bold.ttf b/frontend/amundsen_application/static/fonts/OpenSans-Bold.ttf new file mode 100644 index 0000000000..96fabd865d Binary files /dev/null and b/frontend/amundsen_application/static/fonts/OpenSans-Bold.ttf differ diff --git a/frontend/amundsen_application/static/fonts/OpenSans-Regular.ttf b/frontend/amundsen_application/static/fonts/OpenSans-Regular.ttf new file mode 100644 index 0000000000..2d4da3a6e2 Binary files /dev/null and b/frontend/amundsen_application/static/fonts/OpenSans-Regular.ttf differ diff --git a/frontend/amundsen_application/static/fonts/OpenSans-SemiBold.ttf b/frontend/amundsen_application/static/fonts/OpenSans-SemiBold.ttf new file mode 100644 index 0000000000..fd71fe9da8 Binary files /dev/null and b/frontend/amundsen_application/static/fonts/OpenSans-SemiBold.ttf differ diff --git a/frontend/amundsen_application/static/fonts/Roboto-Bold.ttf b/frontend/amundsen_application/static/fonts/Roboto-Bold.ttf new file mode 100644 index 0000000000..e612852d25 Binary files /dev/null and b/frontend/amundsen_application/static/fonts/Roboto-Bold.ttf differ diff --git a/frontend/amundsen_application/static/fonts/Roboto-Medium.ttf b/frontend/amundsen_application/static/fonts/Roboto-Medium.ttf new file mode 100644 index 0000000000..86d1c52ed5 Binary files /dev/null and b/frontend/amundsen_application/static/fonts/Roboto-Medium.ttf differ diff --git a/frontend/amundsen_application/static/fonts/SpaceMono-Regular.ttf b/frontend/amundsen_application/static/fonts/SpaceMono-Regular.ttf new file mode 100644 index 0000000000..3374aca030 Binary files /dev/null and b/frontend/amundsen_application/static/fonts/SpaceMono-Regular.ttf differ diff --git a/frontend/amundsen_application/static/global.d.ts b/frontend/amundsen_application/static/global.d.ts new file mode 100644 index 0000000000..b7a12826b3 --- /dev/null +++ b/frontend/amundsen_application/static/global.d.ts @@ -0,0 +1,10 @@ +export {}; + +declare const require: { + (path: string): T; + (paths: string[], callback: (...modules: any[]) => void): void; + ensure: ( + paths: string[], + callback: (require: (path: string) => T) => void + ) => void; +}; diff --git a/frontend/amundsen_application/static/images/airflow.jpeg b/frontend/amundsen_application/static/images/airflow.jpeg new file mode 100644 index 0000000000..6d2de3a941 Binary files /dev/null and b/frontend/amundsen_application/static/images/airflow.jpeg differ diff --git 
a/frontend/amundsen_application/static/images/favicons/dev/android-chrome-192x192.png b/frontend/amundsen_application/static/images/favicons/dev/android-chrome-192x192.png new file mode 100644 index 0000000000..3535c0a469 Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/dev/android-chrome-192x192.png differ diff --git a/frontend/amundsen_application/static/images/favicons/dev/android-chrome-256x256.png b/frontend/amundsen_application/static/images/favicons/dev/android-chrome-256x256.png new file mode 100644 index 0000000000..94873114c6 Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/dev/android-chrome-256x256.png differ diff --git a/frontend/amundsen_application/static/images/favicons/dev/apple-touch-icon.png b/frontend/amundsen_application/static/images/favicons/dev/apple-touch-icon.png new file mode 100644 index 0000000000..1b30a1895f Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/dev/apple-touch-icon.png differ diff --git a/frontend/amundsen_application/static/images/favicons/dev/browserconfig.xml b/frontend/amundsen_application/static/images/favicons/dev/browserconfig.xml new file mode 100644 index 0000000000..51850e60ef --- /dev/null +++ b/frontend/amundsen_application/static/images/favicons/dev/browserconfig.xml @@ -0,0 +1,9 @@ + + + + + + #2d89ef + + + diff --git a/frontend/amundsen_application/static/images/favicons/dev/favicon-16x16.png b/frontend/amundsen_application/static/images/favicons/dev/favicon-16x16.png new file mode 100644 index 0000000000..f946b5fc08 Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/dev/favicon-16x16.png differ diff --git a/frontend/amundsen_application/static/images/favicons/dev/favicon-32x32.png b/frontend/amundsen_application/static/images/favicons/dev/favicon-32x32.png new file mode 100644 index 0000000000..41a6cb27bd Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/dev/favicon-32x32.png differ diff --git a/frontend/amundsen_application/static/images/favicons/dev/favicon.ico b/frontend/amundsen_application/static/images/favicons/dev/favicon.ico new file mode 100644 index 0000000000..fb27f67734 Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/dev/favicon.ico differ diff --git a/frontend/amundsen_application/static/images/favicons/dev/mstile-150x150.png b/frontend/amundsen_application/static/images/favicons/dev/mstile-150x150.png new file mode 100644 index 0000000000..c2f3a3983f Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/dev/mstile-150x150.png differ diff --git a/frontend/amundsen_application/static/images/favicons/dev/safari-pinned-tab.svg b/frontend/amundsen_application/static/images/favicons/dev/safari-pinned-tab.svg new file mode 100644 index 0000000000..001dc6d86d --- /dev/null +++ b/frontend/amundsen_application/static/images/favicons/dev/safari-pinned-tab.svg @@ -0,0 +1,25 @@ + + + + +Created by potrace 1.11, written by Peter Selinger 2001-2013 + + + + + diff --git a/frontend/amundsen_application/static/images/favicons/dev/site.webmanifest b/frontend/amundsen_application/static/images/favicons/dev/site.webmanifest new file mode 100644 index 0000000000..5819b5705a --- /dev/null +++ b/frontend/amundsen_application/static/images/favicons/dev/site.webmanifest @@ -0,0 +1,19 @@ +{ + "name": "Amundsen Dev", + "short_name": "Amundsen Dev", + "icons": [ + { + "src": 
"/static/images/favicons/dev/android-chrome-192x192.png", + "sizes": "192x192", + "type": "image/png" + }, + { + "src": "/static/images/favicons/dev/android-chrome-256x256.png", + "sizes": "256x256", + "type": "image/png" + } + ], + "theme_color": "#ffffff", + "background_color": "#ffffff", + "display": "standalone" +} diff --git a/frontend/amundsen_application/static/images/favicons/prod/android-chrome-192x192.png b/frontend/amundsen_application/static/images/favicons/prod/android-chrome-192x192.png new file mode 100644 index 0000000000..98c32bd710 Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/prod/android-chrome-192x192.png differ diff --git a/frontend/amundsen_application/static/images/favicons/prod/android-chrome-256x256.png b/frontend/amundsen_application/static/images/favicons/prod/android-chrome-256x256.png new file mode 100644 index 0000000000..66c16e2d84 Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/prod/android-chrome-256x256.png differ diff --git a/frontend/amundsen_application/static/images/favicons/prod/apple-touch-icon.png b/frontend/amundsen_application/static/images/favicons/prod/apple-touch-icon.png new file mode 100644 index 0000000000..c8029b46ba Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/prod/apple-touch-icon.png differ diff --git a/frontend/amundsen_application/static/images/favicons/prod/browserconfig.xml b/frontend/amundsen_application/static/images/favicons/prod/browserconfig.xml new file mode 100644 index 0000000000..84425a0ac3 --- /dev/null +++ b/frontend/amundsen_application/static/images/favicons/prod/browserconfig.xml @@ -0,0 +1,9 @@ + + + + + + #2d89ef + + + diff --git a/frontend/amundsen_application/static/images/favicons/prod/favicon-16x16.png b/frontend/amundsen_application/static/images/favicons/prod/favicon-16x16.png new file mode 100644 index 0000000000..14cdad8313 Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/prod/favicon-16x16.png differ diff --git a/frontend/amundsen_application/static/images/favicons/prod/favicon-32x32.png b/frontend/amundsen_application/static/images/favicons/prod/favicon-32x32.png new file mode 100644 index 0000000000..74dee77c6a Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/prod/favicon-32x32.png differ diff --git a/frontend/amundsen_application/static/images/favicons/prod/favicon.ico b/frontend/amundsen_application/static/images/favicons/prod/favicon.ico new file mode 100644 index 0000000000..a8216c280b Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/prod/favicon.ico differ diff --git a/frontend/amundsen_application/static/images/favicons/prod/mstile-150x150.png b/frontend/amundsen_application/static/images/favicons/prod/mstile-150x150.png new file mode 100644 index 0000000000..65542be2db Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/prod/mstile-150x150.png differ diff --git a/frontend/amundsen_application/static/images/favicons/prod/safari-pinned-tab.svg b/frontend/amundsen_application/static/images/favicons/prod/safari-pinned-tab.svg new file mode 100644 index 0000000000..001dc6d86d --- /dev/null +++ b/frontend/amundsen_application/static/images/favicons/prod/safari-pinned-tab.svg @@ -0,0 +1,25 @@ + + + + +Created by potrace 1.11, written by Peter Selinger 2001-2013 + + + + + diff --git a/frontend/amundsen_application/static/images/favicons/prod/site.webmanifest 
b/frontend/amundsen_application/static/images/favicons/prod/site.webmanifest new file mode 100644 index 0000000000..dd774fa6b3 --- /dev/null +++ b/frontend/amundsen_application/static/images/favicons/prod/site.webmanifest @@ -0,0 +1,19 @@ +{ + "name": "Amundsen", + "short_name": "Amundsen", + "icons": [ + { + "src": "/static/images/favicons/prod/android-chrome-192x192.png", + "sizes": "192x192", + "type": "image/png" + }, + { + "src": "/static/images/favicons/prod/android-chrome-256x256.png", + "sizes": "256x256", + "type": "image/png" + } + ], + "theme_color": "#ffffff", + "background_color": "#ffffff", + "display": "standalone" +} diff --git a/frontend/amundsen_application/static/images/favicons/staging/android-chrome-192x192.png b/frontend/amundsen_application/static/images/favicons/staging/android-chrome-192x192.png new file mode 100644 index 0000000000..5848c61a58 Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/staging/android-chrome-192x192.png differ diff --git a/frontend/amundsen_application/static/images/favicons/staging/android-chrome-256x256.png b/frontend/amundsen_application/static/images/favicons/staging/android-chrome-256x256.png new file mode 100644 index 0000000000..65c67d8034 Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/staging/android-chrome-256x256.png differ diff --git a/frontend/amundsen_application/static/images/favicons/staging/apple-touch-icon.png b/frontend/amundsen_application/static/images/favicons/staging/apple-touch-icon.png new file mode 100644 index 0000000000..fbd25c30ce Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/staging/apple-touch-icon.png differ diff --git a/frontend/amundsen_application/static/images/favicons/staging/browserconfig.xml b/frontend/amundsen_application/static/images/favicons/staging/browserconfig.xml new file mode 100644 index 0000000000..b15758e783 --- /dev/null +++ b/frontend/amundsen_application/static/images/favicons/staging/browserconfig.xml @@ -0,0 +1,9 @@ + + + + + + #2b5797 + + + diff --git a/frontend/amundsen_application/static/images/favicons/staging/favicon-16x16.png b/frontend/amundsen_application/static/images/favicons/staging/favicon-16x16.png new file mode 100644 index 0000000000..15529435e3 Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/staging/favicon-16x16.png differ diff --git a/frontend/amundsen_application/static/images/favicons/staging/favicon-32x32.png b/frontend/amundsen_application/static/images/favicons/staging/favicon-32x32.png new file mode 100644 index 0000000000..282bd437a0 Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/staging/favicon-32x32.png differ diff --git a/frontend/amundsen_application/static/images/favicons/staging/favicon.ico b/frontend/amundsen_application/static/images/favicons/staging/favicon.ico new file mode 100644 index 0000000000..f883bb6515 Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/staging/favicon.ico differ diff --git a/frontend/amundsen_application/static/images/favicons/staging/mstile-150x150.png b/frontend/amundsen_application/static/images/favicons/staging/mstile-150x150.png new file mode 100644 index 0000000000..a07262d481 Binary files /dev/null and b/frontend/amundsen_application/static/images/favicons/staging/mstile-150x150.png differ diff --git a/frontend/amundsen_application/static/images/favicons/staging/safari-pinned-tab.svg 
b/frontend/amundsen_application/static/images/favicons/staging/safari-pinned-tab.svg new file mode 100644 index 0000000000..001dc6d86d --- /dev/null +++ b/frontend/amundsen_application/static/images/favicons/staging/safari-pinned-tab.svg @@ -0,0 +1,25 @@ + + + + +Created by potrace 1.11, written by Peter Selinger 2001-2013 + + + + + diff --git a/frontend/amundsen_application/static/images/favicons/staging/site.webmanifest b/frontend/amundsen_application/static/images/favicons/staging/site.webmanifest new file mode 100644 index 0000000000..b94603ce1c --- /dev/null +++ b/frontend/amundsen_application/static/images/favicons/staging/site.webmanifest @@ -0,0 +1,19 @@ +{ + "name": "Amundsen", + "short_name": "Amundsen", + "icons": [ + { + "src": "/static/images/favicons/staging/android-chrome-192x192.png", + "sizes": "192x192", + "type": "image/png" + }, + { + "src": "/static/images/favicons/staging/android-chrome-256x256.png", + "sizes": "256x256", + "type": "image/png" + } + ], + "theme_color": "#ffffff", + "background_color": "#ffffff", + "display": "standalone" +} diff --git a/frontend/amundsen_application/static/images/github.png b/frontend/amundsen_application/static/images/github.png new file mode 100644 index 0000000000..2e155e2646 Binary files /dev/null and b/frontend/amundsen_application/static/images/github.png differ diff --git a/frontend/amundsen_application/static/images/icons/Alert-Triangle.svg b/frontend/amundsen_application/static/images/icons/Alert-Triangle.svg new file mode 100644 index 0000000000..59e65b15bc --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Alert-Triangle.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Close.svg b/frontend/amundsen_application/static/images/icons/Close.svg new file mode 100644 index 0000000000..161a70916f --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Close.svg @@ -0,0 +1,19 @@ + + + + Close + Created with Sketch. + + + + + + + + + + + + + + diff --git a/frontend/amundsen_application/static/images/icons/DataQualityWarning.svg b/frontend/amundsen_application/static/images/icons/DataQualityWarning.svg new file mode 100644 index 0000000000..b01b155d29 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/DataQualityWarning.svg @@ -0,0 +1,3 @@ + + + diff --git a/frontend/amundsen_application/static/images/icons/Database.svg b/frontend/amundsen_application/static/images/icons/Database.svg new file mode 100644 index 0000000000..4b1b36831a --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Database.svg @@ -0,0 +1,16 @@ + + + + Database + Created with Sketch. + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Delta-Down.svg b/frontend/amundsen_application/static/images/icons/Delta-Down.svg new file mode 100644 index 0000000000..c44499888a --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Delta-Down.svg @@ -0,0 +1,19 @@ + + + + Delta-Down + Created with Sketch. + + + + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Delta-Up.svg b/frontend/amundsen_application/static/images/icons/Delta-Up.svg new file mode 100644 index 0000000000..1e2ec0771b --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Delta-Up.svg @@ -0,0 +1,19 @@ + + + + Delta-Up + Created with Sketch. 
+ + + + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Dimension.svg b/frontend/amundsen_application/static/images/icons/Dimension.svg new file mode 100644 index 0000000000..a375077906 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Dimension.svg @@ -0,0 +1,16 @@ + + + + Dimension + Created with Sketch. + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Down Arrow.svg b/frontend/amundsen_application/static/images/icons/Down Arrow.svg new file mode 100644 index 0000000000..d51a47dc71 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Down Arrow.svg @@ -0,0 +1,16 @@ + + + + Down Arrow + Created with Sketch. + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Down-Arrow.svg b/frontend/amundsen_application/static/images/icons/Down-Arrow.svg new file mode 100644 index 0000000000..ba3c709282 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Down-Arrow.svg @@ -0,0 +1,16 @@ + + + + Down-Arrow + Created with Sketch. + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Down.svg b/frontend/amundsen_application/static/images/icons/Down.svg new file mode 100644 index 0000000000..f1a1fe0752 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Down.svg @@ -0,0 +1,16 @@ + + + + Down + Created with Sketch. + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Edit.svg b/frontend/amundsen_application/static/images/icons/Edit.svg new file mode 100644 index 0000000000..6051196725 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Edit.svg @@ -0,0 +1 @@ + diff --git a/frontend/amundsen_application/static/images/icons/Edit_Inverted.svg b/frontend/amundsen_application/static/images/icons/Edit_Inverted.svg new file mode 100644 index 0000000000..330fcd5530 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Edit_Inverted.svg @@ -0,0 +1,18 @@ + + + + Edit + Created with Sketch. + + + + + + + + + + + + + diff --git a/frontend/amundsen_application/static/images/icons/Expand.svg b/frontend/amundsen_application/static/images/icons/Expand.svg new file mode 100644 index 0000000000..b58c7629a7 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Expand.svg @@ -0,0 +1,16 @@ + + + + Expand + Created with Sketch. + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Favorite-Filled.svg b/frontend/amundsen_application/static/images/icons/Favorite-Filled.svg new file mode 100644 index 0000000000..db6b3c1ae0 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Favorite-Filled.svg @@ -0,0 +1,16 @@ + + + + Favorite-Filled + Created with Sketch. + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Favorite.svg b/frontend/amundsen_application/static/images/icons/Favorite.svg new file mode 100644 index 0000000000..1c765bb27d --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Favorite.svg @@ -0,0 +1,16 @@ + + + + Favorite + Created with Sketch. 
+ + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Help-Circle.svg b/frontend/amundsen_application/static/images/icons/Help-Circle.svg new file mode 100644 index 0000000000..6c210ad162 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Help-Circle.svg @@ -0,0 +1 @@ + diff --git a/frontend/amundsen_application/static/images/icons/Info-Filled.svg b/frontend/amundsen_application/static/images/icons/Info-Filled.svg new file mode 100644 index 0000000000..77ab089c35 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Info-Filled.svg @@ -0,0 +1,16 @@ + + + + Info-Filled + Created with Sketch. + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Info.svg b/frontend/amundsen_application/static/images/icons/Info.svg new file mode 100644 index 0000000000..4150ab76af --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Info.svg @@ -0,0 +1,16 @@ + + + + Info + Created with Sketch. + + + + + + + + + + + diff --git a/frontend/amundsen_application/static/images/icons/Left.svg b/frontend/amundsen_application/static/images/icons/Left.svg new file mode 100644 index 0000000000..74977ee303 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Left.svg @@ -0,0 +1,18 @@ + + + + Left + Created with Sketch. + + + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Lineage.svg b/frontend/amundsen_application/static/images/icons/Lineage.svg new file mode 100644 index 0000000000..68b2c4a628 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Lineage.svg @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/frontend/amundsen_application/static/images/icons/Loader.svg b/frontend/amundsen_application/static/images/icons/Loader.svg new file mode 100644 index 0000000000..967d83859f --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Loader.svg @@ -0,0 +1,32 @@ + + + + Loader + Created with Sketch. + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Metric.svg b/frontend/amundsen_application/static/images/icons/Metric.svg new file mode 100644 index 0000000000..9a17367109 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Metric.svg @@ -0,0 +1,16 @@ + + + + Metric + Created with Sketch. + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Minimize.svg b/frontend/amundsen_application/static/images/icons/Minimize.svg new file mode 100644 index 0000000000..2c7ad1b008 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Minimize.svg @@ -0,0 +1,16 @@ + + + + Minimize + Created with Sketch. + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/More.svg b/frontend/amundsen_application/static/images/icons/More.svg new file mode 100644 index 0000000000..1308f05ef7 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/More.svg @@ -0,0 +1,16 @@ + + + + More + Created with Sketch. 
+ + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Null Value.svg b/frontend/amundsen_application/static/images/icons/Null Value.svg new file mode 100644 index 0000000000..f5661fa5ba --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Null Value.svg @@ -0,0 +1,16 @@ + + + + Null Value + Created with Sketch. + + + + + + + + + + + diff --git a/frontend/amundsen_application/static/images/icons/Null-Value.svg b/frontend/amundsen_application/static/images/icons/Null-Value.svg new file mode 100644 index 0000000000..d61a08b7e3 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Null-Value.svg @@ -0,0 +1,16 @@ + + + + Null-Value + Created with Sketch. + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Parameters.svg b/frontend/amundsen_application/static/images/icons/Parameters.svg new file mode 100644 index 0000000000..88ce590ba6 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Parameters.svg @@ -0,0 +1,18 @@ + + + + Parameters + Created with Sketch. + + + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Person.svg b/frontend/amundsen_application/static/images/icons/Person.svg new file mode 100644 index 0000000000..ad26bdeab3 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Person.svg @@ -0,0 +1,16 @@ + + + + Person + Created with Sketch. + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Plus-Circle.svg b/frontend/amundsen_application/static/images/icons/Plus-Circle.svg new file mode 100644 index 0000000000..754fb6794d --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Plus-Circle.svg @@ -0,0 +1 @@ + diff --git a/frontend/amundsen_application/static/images/icons/Preview-Fillled.svg b/frontend/amundsen_application/static/images/icons/Preview-Fillled.svg new file mode 100644 index 0000000000..ff67beede8 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Preview-Fillled.svg @@ -0,0 +1,19 @@ + + + + Preview-Fillled + Created with Sketch. + + + + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Preview.svg b/frontend/amundsen_application/static/images/icons/Preview.svg new file mode 100644 index 0000000000..bc5a2cf55a --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Preview.svg @@ -0,0 +1,19 @@ + + + + Preview + Created with Sketch. + + + + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Refresh-cw.svg b/frontend/amundsen_application/static/images/icons/Refresh-cw.svg new file mode 100644 index 0000000000..5097efd944 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Refresh-cw.svg @@ -0,0 +1,2 @@ + + diff --git a/frontend/amundsen_application/static/images/icons/Right.svg b/frontend/amundsen_application/static/images/icons/Right.svg new file mode 100644 index 0000000000..eb11253a42 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Right.svg @@ -0,0 +1,18 @@ + + + + Right + Created with Sketch. 
+ + + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Search.svg b/frontend/amundsen_application/static/images/icons/Search.svg new file mode 100644 index 0000000000..382a724270 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Search.svg @@ -0,0 +1,16 @@ + + + + Search + Created with Sketch. + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Send.svg b/frontend/amundsen_application/static/images/icons/Send.svg new file mode 100644 index 0000000000..42ef2a2438 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Send.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Speech.svg b/frontend/amundsen_application/static/images/icons/Speech.svg new file mode 100644 index 0000000000..04e13c5a29 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Speech.svg @@ -0,0 +1,16 @@ + + + + Speech + Created with Sketch. + + + + + + + + + + + diff --git a/frontend/amundsen_application/static/images/icons/Trash.svg b/frontend/amundsen_application/static/images/icons/Trash.svg new file mode 100644 index 0000000000..f24d55bf64 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Trash.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Up-Arrow.svg b/frontend/amundsen_application/static/images/icons/Up-Arrow.svg new file mode 100644 index 0000000000..295e5b7c02 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Up-Arrow.svg @@ -0,0 +1,19 @@ + + + + Up-Arrow + Created with Sketch. + + + + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/Up.svg b/frontend/amundsen_application/static/images/icons/Up.svg new file mode 100644 index 0000000000..9c805b8acb --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/Up.svg @@ -0,0 +1,18 @@ + + + + Up + Created with Sketch. 
+ + + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/amundsen-logo-dark.svg b/frontend/amundsen_application/static/images/icons/amundsen-logo-dark.svg new file mode 100644 index 0000000000..7ad7c2c406 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/amundsen-logo-dark.svg @@ -0,0 +1 @@ +amundsen_mark_blue \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/amundsen-logo-light.svg b/frontend/amundsen_application/static/images/icons/amundsen-logo-light.svg new file mode 100644 index 0000000000..603a45a5e6 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/amundsen-logo-light.svg @@ -0,0 +1 @@ +amundsen_mark_orange \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/application.svg b/frontend/amundsen_application/static/images/icons/application.svg new file mode 100644 index 0000000000..482801f3da --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/application.svg @@ -0,0 +1,4 @@ + + + + diff --git a/frontend/amundsen_application/static/images/icons/check.svg b/frontend/amundsen_application/static/images/icons/check.svg new file mode 100644 index 0000000000..ca1e1383f1 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/check.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/dashboard.svg b/frontend/amundsen_application/static/images/icons/dashboard.svg new file mode 100644 index 0000000000..ad4566183a --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/dashboard.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/frontend/amundsen_application/static/images/icons/github.svg b/frontend/amundsen_application/static/images/icons/github.svg new file mode 100644 index 0000000000..803f8d6795 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/github.svg @@ -0,0 +1 @@ + diff --git a/frontend/amundsen_application/static/images/icons/logo-bigquery.svg b/frontend/amundsen_application/static/images/icons/logo-bigquery.svg new file mode 100644 index 0000000000..064581a0ff --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/logo-bigquery.svg @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/frontend/amundsen_application/static/images/icons/logo-databricks-sql.svg b/frontend/amundsen_application/static/images/icons/logo-databricks-sql.svg new file mode 100644 index 0000000000..aa4b193ba0 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/logo-databricks-sql.svg @@ -0,0 +1,23 @@ + + + + +Created by potrace 1.10, written by Peter Selinger 2001-2011 + + + + + diff --git a/frontend/amundsen_application/static/images/icons/logo-databricks.png b/frontend/amundsen_application/static/images/icons/logo-databricks.png new file mode 100644 index 0000000000..a80d0d54ed Binary files /dev/null and b/frontend/amundsen_application/static/images/icons/logo-databricks.png differ diff --git a/frontend/amundsen_application/static/images/icons/logo-delta.png b/frontend/amundsen_application/static/images/icons/logo-delta.png new file mode 100644 index 0000000000..89d7527e0d Binary files /dev/null and b/frontend/amundsen_application/static/images/icons/logo-delta.png differ diff --git a/frontend/amundsen_application/static/images/icons/logo-dremio.svg b/frontend/amundsen_application/static/images/icons/logo-dremio.svg new file 
mode 100644 index 0000000000..e37f4cc9a3 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/logo-dremio.svg @@ -0,0 +1 @@ + diff --git a/frontend/amundsen_application/static/images/icons/logo-druid.svg b/frontend/amundsen_application/static/images/icons/logo-druid.svg new file mode 100644 index 0000000000..52db86db3f --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/logo-druid.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/logo-elasticsearch.svg b/frontend/amundsen_application/static/images/icons/logo-elasticsearch.svg new file mode 100644 index 0000000000..b95507cd54 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/logo-elasticsearch.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/logo-hive.svg b/frontend/amundsen_application/static/images/icons/logo-hive.svg new file mode 100644 index 0000000000..031ad38977 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/logo-hive.svg @@ -0,0 +1,51 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/frontend/amundsen_application/static/images/icons/logo-mode.svg b/frontend/amundsen_application/static/images/icons/logo-mode.svg new file mode 100644 index 0000000000..406efd875b --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/logo-mode.svg @@ -0,0 +1,14 @@ + + + + mode-logo + Created with Sketch. + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/logo-oracle.svg b/frontend/amundsen_application/static/images/icons/logo-oracle.svg new file mode 100644 index 0000000000..46c8fefbbe --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/logo-oracle.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/logo-postgres.svg b/frontend/amundsen_application/static/images/icons/logo-postgres.svg new file mode 100644 index 0000000000..251aa652d5 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/logo-postgres.svg @@ -0,0 +1,22 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/frontend/amundsen_application/static/images/icons/logo-powerbi.svg b/frontend/amundsen_application/static/images/icons/logo-powerbi.svg new file mode 100644 index 0000000000..d7ab0423b9 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/logo-powerbi.svg @@ -0,0 +1,36 @@ + + + + PBI Logo + Created with Sketch. + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/logo-presto.svg b/frontend/amundsen_application/static/images/icons/logo-presto.svg new file mode 100644 index 0000000000..bf0e662a38 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/logo-presto.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/logo-redash.svg b/frontend/amundsen_application/static/images/icons/logo-redash.svg new file mode 100644 index 0000000000..c867141cdf --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/logo-redash.svg @@ -0,0 +1,14 @@ + + + + redash-logo + Created with Sketch. 
+ + + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/logo-redshift.svg b/frontend/amundsen_application/static/images/icons/logo-redshift.svg new file mode 100644 index 0000000000..a4eaa60d33 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/logo-redshift.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/logo-snowflake.svg b/frontend/amundsen_application/static/images/icons/logo-snowflake.svg new file mode 100644 index 0000000000..0a1abe0b08 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/logo-snowflake.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/logo-superset.svg b/frontend/amundsen_application/static/images/icons/logo-superset.svg new file mode 100644 index 0000000000..452bce3245 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/logo-superset.svg @@ -0,0 +1,11 @@ + + superset + + + + + + diff --git a/frontend/amundsen_application/static/images/icons/logo-tableau.svg b/frontend/amundsen_application/static/images/icons/logo-tableau.svg new file mode 100644 index 0000000000..28996f1dad --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/logo-tableau.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/logo-teradata.svg b/frontend/amundsen_application/static/images/icons/logo-teradata.svg new file mode 100644 index 0000000000..564495d372 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/logo-teradata.svg @@ -0,0 +1,218 @@ + + + + diff --git a/frontend/amundsen_application/static/images/icons/logo-trino.svg b/frontend/amundsen_application/static/images/icons/logo-trino.svg new file mode 100644 index 0000000000..a5d50dc275 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/logo-trino.svg @@ -0,0 +1 @@ + diff --git a/frontend/amundsen_application/static/images/icons/mail.svg b/frontend/amundsen_application/static/images/icons/mail.svg new file mode 100644 index 0000000000..2af169e83d --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/mail.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/plus.svg b/frontend/amundsen_application/static/images/icons/plus.svg new file mode 100644 index 0000000000..703c5b7b23 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/plus.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/presto-logo.svg b/frontend/amundsen_application/static/images/icons/presto-logo.svg new file mode 100644 index 0000000000..8f4b31f6c3 --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/presto-logo.svg @@ -0,0 +1,59 @@ + + + + +Presto logo + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/frontend/amundsen_application/static/images/icons/slack.svg b/frontend/amundsen_application/static/images/icons/slack.svg new file mode 100644 index 0000000000..5d973466bb --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/slack.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/frontend/amundsen_application/static/images/icons/users.svg b/frontend/amundsen_application/static/images/icons/users.svg new file mode 100644 index 0000000000..aacf6b08ec --- /dev/null +++ b/frontend/amundsen_application/static/images/icons/users.svg @@ -0,0 
+1 @@
+
\ No newline at end of file
diff --git a/frontend/amundsen_application/static/images/loading_spinner.gif b/frontend/amundsen_application/static/images/loading_spinner.gif
new file mode 100644
index 0000000000..bb1983607f
Binary files /dev/null and b/frontend/amundsen_application/static/images/loading_spinner.gif differ
diff --git a/frontend/amundsen_application/static/images/watermark-range.png b/frontend/amundsen_application/static/images/watermark-range.png
new file mode 100644
index 0000000000..72bcf60636
Binary files /dev/null and b/frontend/amundsen_application/static/images/watermark-range.png differ
diff --git a/frontend/amundsen_application/static/jest.config.js b/frontend/amundsen_application/static/jest.config.js
new file mode 100644
index 0000000000..a1c55a79dc
--- /dev/null
+++ b/frontend/amundsen_application/static/jest.config.js
@@ -0,0 +1,59 @@
+module.exports = {
+  coverageThreshold: {
+    './js/config': {
+      branches: 90,
+      functions: 90,
+      lines: 90,
+      statements: 90,
+    },
+    './js/components': {
+      branches: 67, // 75
+      functions: 67, // 75
+      lines: 75, // 75
+      statements: 75, // 75
+    },
+    './js/pages': {
+      branches: 65, // 75
+      functions: 72, // 75
+      lines: 81, // 75
+      statements: 78, // 75
+    },
+    './js/ducks': {
+      branches: 60, // 75
+      functions: 80,
+      lines: 80,
+      statements: 80,
+    },
+    './js/fixtures': {
+      branches: 100,
+      functions: 100,
+      lines: 100,
+      statements: 100,
+    },
+  },
+  roots: ['<rootDir>/js'],
+  setupFiles: ['<rootDir>/test-setup.ts'],
+  transform: {
+    '^.+\\.tsx?$': 'ts-jest',
+    '^.+\\.js$': 'babel-jest',
+  },
+  testRegex: '(test|spec)\\.(j|t)sx?$',
+  moduleDirectories: ['node_modules', 'js'],
+  coveragePathIgnorePatterns: [
+    'stories/*',
+    'constants.ts',
+    'testDataBuilder.ts',
+    '.story.tsx',
+    'js/index.tsx',
+  ],
+  moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json'],
+  moduleNameMapper: {
+    '^.+\\.(css|scss)$': '<rootDir>/node_modules/jest-css-modules',
+    '^axios$': 'axios/dist/node/axios.cjs',
+  },
+  globals: {
+    'ts-jest': {
+      diagnostics: false,
+    },
+  },
+};
diff --git a/frontend/amundsen_application/static/js/components/Alert/Alert.tsx b/frontend/amundsen_application/static/js/components/Alert/Alert.tsx
new file mode 100644
index 0000000000..f03b575501
--- /dev/null
+++ b/frontend/amundsen_application/static/js/components/Alert/Alert.tsx
@@ -0,0 +1,180 @@
+// Copyright Contributors to the Amundsen project.
+// SPDX-License-Identifier: Apache-2.0 + +import * as React from 'react'; +import SanitizedHTML from 'react-sanitized-html'; +import { Modal } from 'react-bootstrap'; + +import { IconSizes } from 'interfaces'; +import { NoticeSeverity } from 'config/config-types'; +import { AlertIcon, InformationIcon } from 'components/SVGIcons'; +import { DefinitionList } from 'components/DefinitionList'; + +import { logClick } from 'utils/analytics'; + +import './styles.scss'; + +const SEVERITY_TO_COLOR_MAP = { + [NoticeSeverity.INFO]: '#3a97d3', // cyan50 + [NoticeSeverity.WARNING]: '#ffb146', // $amber50 + [NoticeSeverity.ALERT]: '#b8072c', // $red70 +}; +const SEVERITY_TO_SEVERITY_CLASS = { + [NoticeSeverity.INFO]: 'is-info', + [NoticeSeverity.WARNING]: 'is-warning', + [NoticeSeverity.ALERT]: 'is-alert', +}; +export const OPEN_PAYLOAD_CTA = 'See details'; +export const PAYLOAD_MODAL_TITLE = 'Summary'; +const PAYLOAD_MODAL_CLOSE_BTN = 'Close'; +const PAYLOAD_DEFINITION_WIDTH = 180; + +export interface AlertProps { + /** Message to show in the alert */ + message: string | React.ReactNode; + /** Severity of the alert (info, warning, or alert) */ + severity?: NoticeSeverity; + /** Link passed to set as the action (for routing links) */ + actionLink?: React.ReactNode; + /** Text of the link action */ + actionText?: string; + /** Href for the link action */ + actionHref?: string; + /** Callback to call when the action is clicked */ + onAction?: (event: React.MouseEvent) => void; + /** Optional extra info to render in a modal */ + payload?: Record; +} + +export const Alert: React.FC = ({ + message, + severity = NoticeSeverity.WARNING, + onAction, + actionText, + actionHref, + actionLink, + payload, +}) => { + const [showPayloadModal, setShowPayloadModal] = React.useState(false); + let action: null | React.ReactNode = null; + + const handleSeeDetails = (e: React.MouseEvent) => { + onAction?.(e); + setShowPayloadModal(true); + logClick(e, { + label: 'See Notice Details', + target_id: 'notice-detail-button', + }); + }; + const handleModalClose = (e: React.MouseEvent) => { + setShowPayloadModal(false); + logClick(e, { + label: 'Close Notice Details', + target_id: 'notice-detail-close', + }); + }; + + if (payload) { + action = ( + + ); + } + + if (actionText && onAction) { + action = ( + + ); + } + + if (actionText && actionHref) { + action = ( + + {actionText} + + ); + } + + if (actionLink) { + action = actionLink; + } + + let iconComponent: React.ReactNode = null; + + if (severity === NoticeSeverity.INFO) { + iconComponent = ( + + ); + } else { + iconComponent = ( + + ); + } + + // If we receive a string, we want to sanitize any html inside + const formattedMessage = + typeof message === 'string' ? : message; + + const payloadDefinitions = payload + ? Object.keys(payload).map((key) => ({ + term: key, + description: payload[key], + })) + : null; + + return ( +
+ {iconComponent} +

{formattedMessage}

+ {action && {action}} + {payloadDefinitions && ( + + + {PAYLOAD_MODAL_TITLE} + + + + + + + + + )} +
+ ); +}; + +export default Alert; diff --git a/frontend/amundsen_application/static/js/components/Alert/AlertList.tsx b/frontend/amundsen_application/static/js/components/Alert/AlertList.tsx new file mode 100644 index 0000000000..73d6672fa4 --- /dev/null +++ b/frontend/amundsen_application/static/js/components/Alert/AlertList.tsx @@ -0,0 +1,69 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +import * as React from 'react'; + +import { NoticeSeverity, NoticeType } from 'config/config-types'; +import { Alert } from './Alert'; + +export interface AlertListProps { + notices: NoticeType[]; +} + +export interface AggregatedAlertListProps { + notices: { + [key: string]: NoticeType; + }; +} + +const aggregateNotices = (notices) => + notices.reduce((accum, notice: NoticeType) => { + if (notice) { + const { messageHtml, severity, payload } = notice; + + if (typeof messageHtml !== 'function') { + if (payload) { + accum[messageHtml] ??= {}; + accum[messageHtml][severity] ??= { + payload: { descriptions: [] }, + }; + accum[messageHtml][severity].payload.descriptions.push(payload); + } else { + accum[messageHtml] = { + [severity]: { ...notice }, + }; + } + } + } + + return accum; + }, {}); + +export const AlertList: React.FC = ({ notices }) => { + if (!notices.length) { + return null; + } + + const aggregated = aggregateNotices(notices); + const NoticeSeverityValues = Object.values(NoticeSeverity); + + return ( +
+ {Object.keys(aggregated).map((notice, idx) => + Object.keys(aggregated[notice]) + .sort( + (a: NoticeSeverity, b: NoticeSeverity) => + NoticeSeverityValues.indexOf(a) - NoticeSeverityValues.indexOf(b) + ) + .map((severity) => ( + + )) + )} +
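+      {/* Notices sharing the same messageHtml are grouped by aggregateNotices above;
+          one Alert is rendered per severity, ordered by the NoticeSeverity enum. */}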
+ ); +}; diff --git a/frontend/amundsen_application/static/js/components/Alert/alert.story.tsx b/frontend/amundsen_application/static/js/components/Alert/alert.story.tsx new file mode 100644 index 0000000000..5b97446ff9 --- /dev/null +++ b/frontend/amundsen_application/static/js/components/Alert/alert.story.tsx @@ -0,0 +1,173 @@ +/* eslint-disable no-alert */ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +import React from 'react'; +import { Meta } from '@storybook/react/types-6-0'; + +import { NoticeSeverity } from 'config/config-types'; + +import StorySection from '../StorySection'; +import { Alert, AlertList } from '.'; + +export const AlertStory = (): React.ReactNode => ( + <> + + { + alert('action executed!'); + }} + /> + + + + + + { + alert('action executed!'); + }} + message="Lorem ipsum dolor sit amet consectetur adipisicing elit. Laboriosam perspiciatis non ipsa officia expedita magnam mollitia, excepturi iste eveniet qui nisi eum illum, quas voluptas, reprehenderit quam molestias cum quisquam!" + /> + + +); + +AlertStory.storyName = 'with basic options'; + +export const AlertWithActionStory = (): React.ReactNode => ( + <> + + + Alert text that has a link + + } + /> + + + { + alert('action executed!'); + }} + /> + + + + + + + Custom Link + + } + /> + + + Link | Ownser)', + 'Failed DAGs': + '1 out of 4 DAGs failed (Link | Ownser)', + 'Root cause': + 'Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod', + Estimate: 'Target fix by MM/DD/YYYY 00:00', + }} + /> + + +); + +AlertWithActionStory.storyName = 'with different types of actions'; + +export const AlertWithSeverityStory = (): React.ReactNode => ( + <> + + + Info alert text that has a link + + } + /> + + + + Warning alert text that has a link + + } + /> + + + + Alert alert text that has a link + + } + /> + + +); + +AlertWithSeverityStory.storyName = 'with different severities'; + +const list = [ + { severity: NoticeSeverity.INFO, messageHtml: 'First alert of the stack' }, + { + severity: NoticeSeverity.WARNING, + messageHtml: 'Second alert of the stack', + }, + { severity: NoticeSeverity.ALERT, messageHtml: 'Third alert of the stack' }, + { + severity: NoticeSeverity.ALERT, + messageHtml: 'Aggregated alert of the stack', + payload: { + term: 'Test term 1', + description: 'Test description 1', + }, + }, + { + severity: NoticeSeverity.ALERT, + messageHtml: 'Aggregated alert of the stack', + payload: { + term: 'Test term 2', + description: 'Test description 2', + }, + }, +]; + +export const AlertListStory = (): React.ReactNode => ( + <> + + + + +); + +AlertListStory.storyName = 'with AlertList'; + +export default { + title: 'Components/Alert', + component: Alert, + decorators: [], +} as Meta; diff --git a/frontend/amundsen_application/static/js/components/Alert/index.rtl.spec.tsx b/frontend/amundsen_application/static/js/components/Alert/index.rtl.spec.tsx new file mode 100644 index 0000000000..988d530ef9 --- /dev/null +++ b/frontend/amundsen_application/static/js/components/Alert/index.rtl.spec.tsx @@ -0,0 +1,252 @@ +// Copyright Contributors to the Amundsen project. 
+// SPDX-License-Identifier: Apache-2.0 +import * as React from 'react'; +import { render, screen, waitFor } from '@testing-library/react'; +import userEvent from '@testing-library/user-event'; +import '@testing-library/jest-dom'; + +import { NoticeSeverity } from 'config/config-types'; + +import { Alert, AlertProps, OPEN_PAYLOAD_CTA, PAYLOAD_MODAL_TITLE } from '.'; + +const setup = (propOverrides?: Partial) => { + const props: AlertProps = { + message: 'Test Message', + onAction: jest.fn(), + ...propOverrides, + }; + + render(); + + const user = userEvent.setup(); + + return { + props, + user, + }; +}; + +describe('Alert', () => { + describe('render', () => { + it('should render an alert icon', () => { + setup(); + const expected = 1; + const actual = screen.getAllByTestId('warning-icon').length; + + expect(actual).toBe(expected); + }); + + it('should render the alert message text', () => { + const { props } = setup(); + const expected = props.message as string; + const actual = screen.getByText(expected); + + expect(actual).toBeInTheDocument(); + }); + + describe('when passing an action text and action handler', () => { + it('should render the action button', () => { + setup({ actionText: 'Action Text' }); + const expected = 1; + const actual = screen.getAllByRole('button').length; + + expect(actual).toBe(expected); + }); + + it('should render the action text', () => { + const { props } = setup({ actionText: 'Action Text' }); + const actual = screen.getByText(props.actionText as string); + + expect(actual).toBeInTheDocument(); + }); + }); + + describe('when passing an action text and action href', () => { + it('should render the action link', () => { + setup({ + actionHref: 'http://testSite.com', + actionText: 'Action Text', + }); + const expected = 1; + const actual = screen.getAllByRole('link').length; + + expect(actual).toBe(expected); + }); + + it('should render the action text', () => { + const { props } = setup({ + actionHref: 'http://testSite.com', + actionText: 'Action Text', + }); + const actual = screen.getByText(props.actionText as string); + + expect(actual).toBeInTheDocument(); + }); + }); + + describe('when passing a custom action link', () => { + it('should render the custom action link', () => { + setup({ + actionLink: ( + + Custom Link + + ), + }); + const expected = 1; + const actual = screen.getAllByRole('link').length; + + expect(actual).toBe(expected); + }); + }); + + describe('when passing a severity', () => { + it('should render the warning icon by default', () => { + setup(); + const expected = 1; + const actual = screen.getAllByTestId('warning-icon').length; + + expect(actual).toBe(expected); + }); + + it('should render the info icon when info severity', () => { + setup({ severity: NoticeSeverity.INFO }); + const expected = 1; + const actual = screen.getAllByTestId('info-icon').length; + + expect(actual).toBe(expected); + }); + + it('should render the alert icon when alert severity', () => { + setup({ severity: NoticeSeverity.ALERT }); + const expected = 1; + const actual = screen.getAllByTestId('alert-icon').length; + + expect(actual).toBe(expected); + }); + + it('should render the alert icon when warning severity', () => { + setup({ severity: NoticeSeverity.WARNING }); + const expected = 1; + const actual = screen.getAllByTestId('warning-icon').length; + + expect(actual).toBe(expected); + }); + }); + + describe('when passing a payload', () => { + const testPayload = { + testKey: 'testValue', + testKey2: 'testHTMLVAlue Lyft', + }; + + it('should render the "see 
details" button link', () => { + setup({ payload: testPayload }); + const seeDetailsButton = screen.getByRole('button', { + name: OPEN_PAYLOAD_CTA, + }); + + expect(seeDetailsButton).toBeInTheDocument(); + }); + }); + }); + + describe('lifetime', () => { + describe('when clicking on the action button', () => { + it('should call the onAction handler', () => { + const handlerSpy = jest.fn(); + const { user } = setup({ + actionText: 'Action Text', + onAction: handlerSpy, + }); + const actionButton = screen.getByRole('button'); + + user.click(actionButton); + + waitFor(() => { + expect(handlerSpy).toHaveBeenCalledTimes(1); + }); + }); + }); + + describe('when clicking on the see details button of a payload alert', () => { + const testPayload = { + testKey: 'testValue', + testKey2: 'testHTMLVAlue Lyft', + }; + + it('should call the onAction handler', () => { + const handlerSpy = jest.fn(); + const { user } = setup({ + onAction: handlerSpy, + payload: testPayload, + }); + + const seeDetailsButton = screen.getByRole('button', { + name: OPEN_PAYLOAD_CTA, + }); + + user.click(seeDetailsButton); + waitFor(() => { + expect(handlerSpy).toHaveBeenCalledTimes(1); + }); + }); + + it('should render the alert payload modal', () => { + const { user } = setup({ payload: testPayload }); + + user.click(screen.getByText(OPEN_PAYLOAD_CTA)); + + waitFor(() => { + const alertPayloadModal = screen.getByRole('dialog'); + + expect(alertPayloadModal).toBeInTheDocument(); + }); + }); + + it('should render the alert payload modal header with the payload', () => { + const { user } = setup({ payload: testPayload }); + + user.click(screen.getByText(OPEN_PAYLOAD_CTA)); + + waitFor(() => { + const alertPayloadModalHeader = screen.getByRole('heading', { + name: PAYLOAD_MODAL_TITLE, + }); + + expect(alertPayloadModalHeader).toBeInTheDocument(); + }); + }); + + it('should render the alert payload modal body with the payload', () => { + const { user } = setup({ + payload: testPayload, + }); + const expected = 1; + + user.click(screen.getByText(OPEN_PAYLOAD_CTA)); + + waitFor(() => { + const actual = screen.queryAllByTestId('alert-payload').length; + + expect(actual).toEqual(expected); + }); + }); + + it('should render the alert payload modal footer with a close button', () => { + const { user } = setup({ + payload: testPayload, + }); + const expected = 1; + + user.click(screen.getByText(OPEN_PAYLOAD_CTA)); + + waitFor(() => { + const actual = screen.queryAllByTestId('alert-payload-close').length; + + expect(actual).toEqual(expected); + }); + }); + }); + }); +}); diff --git a/frontend/amundsen_application/static/js/components/Alert/index.tsx b/frontend/amundsen_application/static/js/components/Alert/index.tsx new file mode 100644 index 0000000000..fb3695c60c --- /dev/null +++ b/frontend/amundsen_application/static/js/components/Alert/index.tsx @@ -0,0 +1,5 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +export * from './Alert'; +export * from './AlertList'; diff --git a/frontend/amundsen_application/static/js/components/Alert/styles.scss b/frontend/amundsen_application/static/js/components/Alert/styles.scss new file mode 100644 index 0000000000..79549fca48 --- /dev/null +++ b/frontend/amundsen_application/static/js/components/Alert/styles.scss @@ -0,0 +1,62 @@ +// Copyright Contributors to the Amundsen project. 
+// SPDX-License-Identifier: Apache-2.0 + +@import 'variables'; +@import 'typography'; + +$alert-border-radius: 4px; +$alert-warning-background: #fff0d4; +$alert-alert-background: #ffe4dd; + +.alert-list { + margin-top: $spacer-3; + margin-bottom: 0; + + .alert { + margin-bottom: $spacer-1; + } +} + +.alert { + border-radius: $alert-border-radius; + display: flex; + padding: $spacer-1 $spacer-1 * 1.5 $spacer-1 $spacer-2; + justify-content: flex-start; + box-shadow: $elevation-level2; + border: none; + + &.is-info { + background-color: $body-bg; + } + + &.is-warning { + background-color: $alert-warning-background; + } + + &.is-alert { + background-color: $alert-alert-background; + } + + .alert-message { + @extend %text-body-w2; + + margin: auto auto auto 0; + display: inline; + } + + .alert-triangle-svg-icon, + .info-svg-icon { + flex-shrink: 0; + align-self: center; + margin-right: $spacer-1; + } + + .info-svg-icon { + margin-right: $spacer-half; + margin-left: -$spacer-half; + } + + .alert-action { + margin: auto 0 auto auto; + } +} diff --git a/frontend/amundsen_application/static/js/components/AvatarLabel/index.spec.tsx b/frontend/amundsen_application/static/js/components/AvatarLabel/index.spec.tsx new file mode 100644 index 0000000000..5c65463718 --- /dev/null +++ b/frontend/amundsen_application/static/js/components/AvatarLabel/index.spec.tsx @@ -0,0 +1,65 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +import * as React from 'react'; +import * as Avatar from 'react-avatar'; + +import { shallow } from 'enzyme'; + +import AvatarLabel, { AvatarLabelProps } from '.'; + +describe('AvatarLabel', () => { + const setup = (propOverrides?: Partial) => { + const props: AvatarLabelProps = { + ...propOverrides, + }; + const wrapper = shallow(); + + return { + props, + wrapper, + }; + }; + + describe('render', () => { + let props: AvatarLabelProps; + let wrapper; + + beforeAll(() => { + ({ props, wrapper } = setup({ + avatarClass: 'test', + label: 'testLabel', + labelClass: 'test', + src: 'testSrc', + })); + }); + + it('renders Avatar with correct props', () => { + expect(wrapper.find(Avatar).props()).toMatchObject({ + className: props.avatarClass, + name: props.label, + src: props.src, + size: 24, + round: true, + }); + }); + + describe('renders label', () => { + let element; + + beforeAll(() => { + element = wrapper.find('.avatar-label'); + }); + + it('with correct text', () => { + expect(element.text()).toEqual(props.label); + }); + + it('with correct style', () => { + expect(element.props().className).toBe( + `avatar-label text-body-w2 ${props.labelClass}` + ); + }); + }); + }); +}); diff --git a/frontend/amundsen_application/static/js/components/AvatarLabel/index.tsx b/frontend/amundsen_application/static/js/components/AvatarLabel/index.tsx new file mode 100644 index 0000000000..87e6f250b4 --- /dev/null +++ b/frontend/amundsen_application/static/js/components/AvatarLabel/index.tsx @@ -0,0 +1,36 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +import * as React from 'react'; +import * as Avatar from 'react-avatar'; + +import './styles.scss'; + +export interface AvatarLabelProps { + avatarClass?: string; + labelClass?: string; + label?: string; + round?: boolean; + src?: string; +} + +const AvatarLabel: React.FC = ({ + avatarClass, + labelClass = 'text-secondary', + label = '', + round = true, + src = '', +}: AvatarLabelProps) => ( +
+  <div className="avatar-label-component">
+    <Avatar
+      className={avatarClass}
+      name={label}
+      round={round}
+      size={24}
+      src={src}
+    />
+    <span className={`avatar-label text-body-w2 ${labelClass}`}>{label}</span>
+  </div>
+); + +export default AvatarLabel; diff --git a/frontend/amundsen_application/static/js/components/AvatarLabel/styles.scss b/frontend/amundsen_application/static/js/components/AvatarLabel/styles.scss new file mode 100644 index 0000000000..e92f9957d3 --- /dev/null +++ b/frontend/amundsen_application/static/js/components/AvatarLabel/styles.scss @@ -0,0 +1,30 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +@import 'variables'; + +.avatar-label-component { + display: inline-block; + + .avatar-label { + cursor: inherit; + margin-left: 8px; + min-width: 0; + vertical-align: middle; + } + + .gray-avatar { + div { + background: $gray20 !important; + color: $gray20 !important; + } + } +} + +.avatar-overlap { + margin-left: -5px; + + &:first-child { + margin-left: 0; + } +} diff --git a/frontend/amundsen_application/static/js/components/Bookmark/BookmarkIcon/index.spec.tsx b/frontend/amundsen_application/static/js/components/Bookmark/BookmarkIcon/index.spec.tsx new file mode 100644 index 0000000000..3f2d6a1e0e --- /dev/null +++ b/frontend/amundsen_application/static/js/components/Bookmark/BookmarkIcon/index.spec.tsx @@ -0,0 +1,152 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +import * as React from 'react'; + +import { shallow } from 'enzyme'; + +import globalState from 'fixtures/globalState'; + +import { ResourceType } from 'interfaces'; + +import * as Analytics from 'utils/analytics'; + +import { + BookmarkIcon, + BookmarkIconProps, + mapDispatchToProps, + mapStateToProps, +} from '.'; + +const logClickSpy = jest.spyOn(Analytics, 'logClick'); + +logClickSpy.mockImplementation(() => null); + +const setup = (propOverrides?: Partial) => { + const props: BookmarkIconProps = { + bookmarkKey: 'someKey', + isBookmarked: true, + large: false, + addBookmark: jest.fn(), + removeBookmark: jest.fn(), + resourceType: ResourceType.table, + ...propOverrides, + }; + const wrapper = shallow(); + + return { props, wrapper }; +}; + +describe('BookmarkIcon', () => { + describe('handleClick', () => { + const clickEvent = { + preventDefault: jest.fn(), + stopPropagation: jest.fn(), + }; + + it('stops propagation and prevents default', () => { + const { wrapper } = setup(); + + wrapper.find('div').simulate('click', clickEvent); + + expect(clickEvent.preventDefault).toHaveBeenCalled(); + expect(clickEvent.stopPropagation).toHaveBeenCalled(); + }); + + it('bookmarks an unbookmarked resource', () => { + const { props, wrapper } = setup({ + isBookmarked: false, + }); + + wrapper.find('div').simulate('click', clickEvent); + + expect(props.addBookmark).toHaveBeenCalledWith( + props.bookmarkKey, + props.resourceType + ); + }); + + it('unbookmarks a bookmarked resource', () => { + const { props, wrapper } = setup({ + isBookmarked: true, + }); + + wrapper.find('div').simulate('click', clickEvent); + + expect(props.removeBookmark).toHaveBeenCalledWith( + props.bookmarkKey, + props.resourceType + ); + }); + }); + + describe('render', () => { + it('renders an empty bookmark when not bookmarked', () => { + const { wrapper } = setup({ isBookmarked: false }); + + expect(wrapper.find('.icon-bookmark').exists()).toBe(true); + }); + + it('renders a filled star when bookmarked', () => { + const { wrapper } = setup({ isBookmarked: true }); + + expect(wrapper.find('.icon-bookmark-filled').exists()).toBe(true); + }); + + it('renders a large star when specified', () => { + const { wrapper } = setup({ large: true }); + + 
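+      // .bookmark-large is the 40px icon variant defined in BookmarkIcon/styles.scss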
expect(wrapper.find('.bookmark-large').exists()).toBe(true); + }); + }); +}); + +describe('mapDispatchToProps', () => { + let dispatch; + let props; + + beforeAll(() => { + dispatch = jest.fn(() => Promise.resolve()); + props = mapDispatchToProps(dispatch); + }); + + it('sets addBookmark on the props', () => { + expect(props.addBookmark).toBeInstanceOf(Function); + }); + + it('sets removeBookmark on the props', () => { + expect(props.removeBookmark).toBeInstanceOf(Function); + }); +}); + +describe('mapStateToProps', () => { + it('sets the bookmarkKey on the props', () => { + const ownProps = { + bookmarkKey: 'test_bookmark_key', + resourceType: ResourceType.table, + }; + const result = mapStateToProps(globalState, ownProps); + + expect(result.bookmarkKey).toEqual(ownProps.bookmarkKey); + }); + + it('sets isBookmarked to false when the resource key is not bookmarked', () => { + const ownProps = { + bookmarkKey: 'not_bookmarked_key', + resourceType: ResourceType.table, + }; + const result = mapStateToProps(globalState, ownProps); + + expect(result.isBookmarked).toBe(false); + }); + + it('sets isBookmarked to true when the resource key is bookmarked', () => { + const ownProps = { + bookmarkKey: 'bookmarked_key', + resourceType: ResourceType.table, + }; + const result = mapStateToProps(globalState, ownProps); + + expect(result.isBookmarked).toBe(true); + }); +}); diff --git a/frontend/amundsen_application/static/js/components/Bookmark/BookmarkIcon/index.tsx b/frontend/amundsen_application/static/js/components/Bookmark/BookmarkIcon/index.tsx new file mode 100644 index 0000000000..21206f2e91 --- /dev/null +++ b/frontend/amundsen_application/static/js/components/Bookmark/BookmarkIcon/index.tsx @@ -0,0 +1,100 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +import * as React from 'react'; +import { bindActionCreators } from 'redux'; +import { connect } from 'react-redux'; + +import { addBookmark, removeBookmark } from 'ducks/bookmark/reducer'; +import { + AddBookmarkRequest, + RemoveBookmarkRequest, +} from 'ducks/bookmark/types'; +import { GlobalState } from 'ducks/rootReducer'; +import { logClick } from 'utils/analytics'; + +import { ResourceType } from 'interfaces'; + +import './styles.scss'; + +interface StateFromProps { + isBookmarked: boolean; +} + +interface DispatchFromProps { + addBookmark: (key: string, type: ResourceType) => AddBookmarkRequest; + removeBookmark: (key: string, type: ResourceType) => RemoveBookmarkRequest; +} + +interface OwnProps { + bookmarkKey: string; + large?: boolean; + resourceType: ResourceType; +} + +export type BookmarkIconProps = StateFromProps & DispatchFromProps & OwnProps; + +export class BookmarkIcon extends React.Component { + public static defaultProps: Partial = { + large: false, + }; + + handleClick = (e: React.MouseEvent) => { + e.stopPropagation(); + e.preventDefault(); + const { + isBookmarked, + removeBookmark, + bookmarkKey, + resourceType, + addBookmark, + } = this.props; + + if (isBookmarked) { + logClick(e, { + label: 'Remove Bookmark', + target_id: `remove-${resourceType}-bookmark-button`, + }); + removeBookmark(bookmarkKey, resourceType); + } else { + logClick(e, { + label: 'Add Bookmark', + target_id: `add-${resourceType}-bookmark-button`, + }); + addBookmark(bookmarkKey, resourceType); + } + }; + + render() { + const { large, isBookmarked } = this.props; + + return ( +
+      <div
+        className={'bookmark-icon' + (large ? ' bookmark-large' : '')}
+        onClick={this.handleClick}
+      >
+        <img
+          className={
+            'icon ' + (isBookmarked ? 'icon-bookmark-filled' : 'icon-bookmark')
+          }
+          alt=""
+        />
+      </div>
+ ); + } +} + +export const mapStateToProps = (state: GlobalState, ownProps: OwnProps) => ({ + bookmarkKey: ownProps.bookmarkKey, + isBookmarked: state.bookmarks.myBookmarks[ownProps.resourceType].some( + (bookmark) => bookmark.key === ownProps.bookmarkKey + ), +}); + +export const mapDispatchToProps = (dispatch: any) => + bindActionCreators({ addBookmark, removeBookmark }, dispatch); + +export default connect( + mapStateToProps, + mapDispatchToProps +)(BookmarkIcon); diff --git a/frontend/amundsen_application/static/js/components/Bookmark/BookmarkIcon/styles.scss b/frontend/amundsen_application/static/js/components/Bookmark/BookmarkIcon/styles.scss new file mode 100644 index 0000000000..f5a53cc6d2 --- /dev/null +++ b/frontend/amundsen_application/static/js/components/Bookmark/BookmarkIcon/styles.scss @@ -0,0 +1,52 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +@import 'variables'; + +.bookmark-icon { + border-radius: 50%; + cursor: pointer; + display: inline-block; + height: 32px; + margin-left: 4px; + padding: 4px; + vertical-align: top; + width: 32px; + + &.bookmark-large { + height: 40px; + width: 40px; + + .icon { + height: 32px; + -webkit-mask-size: 32px; + mask-size: 32px; + width: 32px; + } + } + + &:hover, + &:focus { + background-color: $body-bg-tertiary; + } + + .icon { + margin: 0; + + &.icon-bookmark { + &, + &:hover, + &:focus { + background-color: $stroke !important; + } + } + + &.icon-bookmark-filled { + &, + &:hover, + &:focus { + background-color: gold !important; + } + } + } +} diff --git a/frontend/amundsen_application/static/js/components/Button/bootstrap-button.story.tsx b/frontend/amundsen_application/static/js/components/Button/bootstrap-button.story.tsx new file mode 100644 index 0000000000..e45563b3bd --- /dev/null +++ b/frontend/amundsen_application/static/js/components/Button/bootstrap-button.story.tsx @@ -0,0 +1,75 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +import React from 'react'; + +import StorySection from '../StorySection'; + +export const BootstrapButtonStory = () => ( + <> + + + + + + + + {/* eslint-disable-next-line jsx-a11y/anchor-is-valid */} + + Button on Link + + + + + + + + + + + + + + + + + + + + + + + + + + + +); + +BootstrapButtonStory.storyName = 'Bootstrap Buttons'; + +export default { + title: 'Components/Buttons', +}; diff --git a/frontend/amundsen_application/static/js/components/Button/custom-button.story.tsx b/frontend/amundsen_application/static/js/components/Button/custom-button.story.tsx new file mode 100644 index 0000000000..467f78be72 --- /dev/null +++ b/frontend/amundsen_application/static/js/components/Button/custom-button.story.tsx @@ -0,0 +1,78 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +import React from 'react'; + +import { Binoculars } from 'components/SVGIcons'; + +import StorySection from '../StorySection'; + +export const CustomButtonStory = () => ( + <> + + + + + + + + + + + + + + + + + + +
+
+ +
+
+
+ +
+ Group Message + + + +
+
+ +); + +CustomButtonStory.storyName = 'Custom Buttons'; + +export default { + title: 'Components/Buttons', +}; diff --git a/frontend/amundsen_application/static/js/components/Card/card.story.tsx b/frontend/amundsen_application/static/js/components/Card/card.story.tsx new file mode 100644 index 0000000000..c5626ed62f --- /dev/null +++ b/frontend/amundsen_application/static/js/components/Card/card.story.tsx @@ -0,0 +1,34 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +import React from 'react'; + +import StorySection from '../StorySection'; +import Card from '.'; + +export default { + title: 'Components/Cards', +}; + +export const Cards = () => ( + <> + + + + + + + + + + +); + +Cards.storyName = 'Cards'; diff --git a/frontend/amundsen_application/static/js/components/Card/index.spec.tsx b/frontend/amundsen_application/static/js/components/Card/index.spec.tsx new file mode 100644 index 0000000000..5ac0d7ad9f --- /dev/null +++ b/frontend/amundsen_application/static/js/components/Card/index.spec.tsx @@ -0,0 +1,190 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +import * as React from 'react'; +import { Link, BrowserRouter } from 'react-router-dom'; +import { mount } from 'enzyme'; + +import Card, { CardProps } from '.'; + +const setup = (propOverrides?: Partial) => { + const props = { + ...propOverrides, + }; + const wrapper = mount( + + + + ); + + return { props, wrapper }; +}; + +describe('Card', () => { + describe('render', () => { + it('renders without issues', () => { + expect(() => { + setup(); + }).not.toThrow(); + }); + + it('renders the main container', () => { + const { wrapper } = setup(); + const expected = 1; + const actual = wrapper.find('.card').length; + + expect(actual).toEqual(expected); + }); + + describe('header', () => { + it('renders a header section', () => { + const { wrapper } = setup(); + const expected = 1; + const actual = wrapper.find('.card-header').length; + + expect(actual).toEqual(expected); + }); + + describe('subtitle', () => { + it('renders a title if passed', () => { + const { wrapper } = setup({ title: 'test title' }); + const expected = 1; + const actual = wrapper.find('.card-title').length; + + expect(actual).toEqual(expected); + }); + + it('does not render a title if missing', () => { + const { wrapper } = setup(); + const expected = 0; + const actual = wrapper.find('.card-title').length; + + expect(actual).toEqual(expected); + }); + }); + + describe('subtitle', () => { + it('renders a subtitle if passed', () => { + const { wrapper } = setup({ subtitle: 'test subtitle' }); + const expected = 1; + const actual = wrapper.find('.card-subtitle').length; + + expect(actual).toEqual(expected); + }); + + it('does not render a subtitle if missing', () => { + const { wrapper } = setup(); + const expected = 0; + const actual = wrapper.find('.card-subtitle').length; + + expect(actual).toEqual(expected); + }); + }); + }); + + describe('body', () => { + it('renders a body section', () => { + const { wrapper } = setup(); + const expected = 1; + const actual = wrapper.find('.card-body').length; + + expect(actual).toEqual(expected); + }); + + describe('copy', () => { + it('renders a copy if passed', () => { + const { wrapper } = setup({ copy: 'test copy' }); + const expected = 1; + const actual = wrapper.find('.card-copy').length; + + expect(actual).toEqual(expected); + }); + + it('does not render a copy if missing', () => { + const { wrapper } = setup(); + const expected 
= 0; + const actual = wrapper.find('.card-copy').length; + + expect(actual).toEqual(expected); + }); + }); + }); + + describe('when is loading', () => { + it('holds a loading state', () => { + const { wrapper } = setup({ isLoading: true }); + const expected = 1; + const actual = wrapper.find('.card.is-loading').length; + + expect(actual).toEqual(expected); + }); + + it('renders a shimmer loader', () => { + const { wrapper } = setup({ isLoading: true }); + const expected = 1; + const actual = wrapper.find('.card-shimmer-loader').length; + + expect(actual).toEqual(expected); + }); + + it('renders five rows of line loaders', () => { + const { wrapper } = setup({ isLoading: true }); + const expected = 5; + const actual = wrapper.find('.card-shimmer-row').length; + + expect(actual).toEqual(expected); + }); + }); + + describe('when an href is passed', () => { + it('should render a link', () => { + const testPath = 'fakePath'; + const { wrapper } = setup({ href: testPath }); + const expected = 1; + const actual = wrapper.find('a.card').length; + + expect(actual).toEqual(expected); + }); + + it('renders a react router Link', () => { + const testPath = 'fakePath'; + const { wrapper } = setup({ href: testPath }); + const expected = 1; + const actual = wrapper.find(Link).length; + + expect(actual).toEqual(expected); + }); + + it('sets the link to the passed href', () => { + const testPath = 'fakePath'; + const { wrapper } = setup({ href: testPath }); + const expected = '/' + testPath; + const actual = wrapper + .find('a.card') + .getDOMNode() + .attributes.getNamedItem('href')?.value; + + expect(actual).toEqual(expected); + }); + }); + }); + + describe('lifetime', () => { + describe('when clicking on an interactive card', () => { + it('should call the onClick handler', () => { + const clickSpy = jest.fn(); + const { wrapper } = setup({ + onClick: clickSpy, + href: 'testPath', + }); + const expected = 1; + + wrapper.find(Link).simulate('click'); + + const actual = clickSpy.mock.calls.length; + + expect(actual).toEqual(expected); + }); + }); + }); +}); diff --git a/frontend/amundsen_application/static/js/components/Card/index.tsx b/frontend/amundsen_application/static/js/components/Card/index.tsx new file mode 100644 index 0000000000..4bcaf194e2 --- /dev/null +++ b/frontend/amundsen_application/static/js/components/Card/index.tsx @@ -0,0 +1,80 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +import * as React from 'react'; +import { Link } from 'react-router-dom'; + +import './styles.scss'; + +export interface CardProps { + title?: string; + subtitle?: string; + copy?: string | JSX.Element; + isLoading?: boolean; + href?: string; + onClick?: (e: React.SyntheticEvent) => void; + type?: string; +} + +const CardShimmerLoader: React.FC = () => ( +
+
+
+
+ +
+
+
+
+
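+    {/* Loading placeholder: five .card-shimmer-row lines, matching
+        $shimmer-loader-items in styles.scss and the count asserted in index.spec.tsx */}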
+); + +const Card: React.FC = ({ + href, + title, + subtitle, + copy, + onClick = undefined, + isLoading = false, + type, +}) => { + let card; + let cardContent = ( + <> +
+ {title &&

{title}

} + {subtitle &&

{subtitle}

} +
+
+ {copy &&
{copy}
} +
+ + ); + + if (isLoading) { + cardContent = ; + } + + if (href) { + card = ( + + {cardContent} + + ); + } else { + card = ( +
+ {cardContent} +
+ ); + } + + return <>{card}; +}; + +export default Card; diff --git a/frontend/amundsen_application/static/js/components/Card/styles.scss b/frontend/amundsen_application/static/js/components/Card/styles.scss new file mode 100644 index 0000000000..f18d236f4b --- /dev/null +++ b/frontend/amundsen_application/static/js/components/Card/styles.scss @@ -0,0 +1,95 @@ +@import 'variables'; +@import 'typography'; + +$shimmer-loader-items: 1, 2, 3, 4, 5; +$shimmer-loader-row-height: 16px; +$shimmer-loader-row-min-width: 90; +$shimmer-loader-row-width: 160; + +$card-height: 180px; +$card-header-height: 60px; +$card-border-size: 1px; +$card-focus-border-size: 2px; + +$card-title-max-lines: 2; +$card-copy-max-lines: 3; + +.card { + display: block; + padding: $spacer-3; + border-top: $card-border-size solid $gray20; + border-bottom: $card-border-size solid $gray20; + height: $card-height; + + &.is-link { + &:focus { + text-decoration: none; + border: $card-focus-border-size solid $blue80; + border-radius: $spacer-1/2; + outline-offset: 0; + } + + &:hover, + &:active { + text-decoration: none; + box-shadow: $elevation-level2; + border: 0; + } + } +} + +.card-header { + height: $card-header-height; +} + +.card-title { + @extend %text-title-w2; + + color: $text-primary; + + @include truncate($w2-font-size, $w2-line-height, $card-title-max-lines); +} + +.card-subtitle { + @extend %text-body-w3; + + color: $text-secondary; +} + +.card-copy { + @extend %text-body-w3; + + color: $text-primary; + margin: 0; + + @include truncate($w3-font-size, $w3-line-height, $card-copy-max-lines); +} + +.card-body { + padding-top: $spacer-2; +} + +// Shimmer Loader +.card-shimmer-loader { + width: 100%; +} + +.card-shimmer-row { + height: $shimmer-loader-row-height; + width: $shimmer-loader-row-min-width + px; + margin-bottom: $spacer-1; + + &:last-child { + margin-bottom: 0; + } +} + +@each $line in $shimmer-loader-items { + .shimmer-row-line--#{$line} { + width: $shimmer-loader-row-width + px; + } +} + +.card-shimmer-loader-body { + margin-top: $spacer-4; +} diff --git a/frontend/amundsen_application/static/js/components/DefinitionList/definitionList.story.tsx b/frontend/amundsen_application/static/js/components/DefinitionList/definitionList.story.tsx new file mode 100644 index 0000000000..c22070bbc1 --- /dev/null +++ b/frontend/amundsen_application/static/js/components/DefinitionList/definitionList.story.tsx @@ -0,0 +1,111 @@ +/* eslint-disable no-alert */ +// Copyright Contributors to the Amundsen project. 
+// SPDX-License-Identifier: Apache-2.0 +import React from 'react'; +import { Meta } from '@storybook/react/types-6-0'; + +import StorySection from '../StorySection'; +import { DefinitionList } from '.'; + +export const DefinitionListStory = (): React.ReactNode => ( + <> + + + + + Link | Owner)', + }, + { + term: 'Failed DAGs', + description: + '1 out of 4 DAGs failed (Link | Owner)', + }, + ]} + /> + + + Link | Owner)', + }, + ]} + /> + + + Owner', + }, + { + 'Failed Check 2': 'coco.fact_rides', + Owner: 'Just a normal string', + }, + ], + }, + ]} + /> + + + Owner', + }, + { + 'Failed Check 2': 'coco.fact_rides', + Owner: 'Just a normal string', + }, + ], + }, + ]} + /> + + +); + +DefinitionListStory.storyName = 'with basic options'; + +export default { + title: 'Components/DefinitionList', + component: DefinitionList, + decorators: [], +} as Meta; diff --git a/frontend/amundsen_application/static/js/components/DefinitionList/index.spec.tsx b/frontend/amundsen_application/static/js/components/DefinitionList/index.spec.tsx new file mode 100644 index 0000000000..0b667676bb --- /dev/null +++ b/frontend/amundsen_application/static/js/components/DefinitionList/index.spec.tsx @@ -0,0 +1,180 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +import * as React from 'react'; +import { mount } from 'enzyme'; + +import { DefinitionList, DefinitionListProps } from '.'; + +const setup = (propOverrides?: Partial) => { + const props: DefinitionListProps = { + definitions: [{ term: 'testTerm', description: 'testDescription' }], + ...propOverrides, + }; + // eslint-disable-next-line react/jsx-props-no-spreading + const wrapper = mount(); + + return { + props, + wrapper, + }; +}; + +describe('DefinitionList', () => { + describe('render', () => { + it('should render a definition list', () => { + const { wrapper } = setup(); + const expected = 1; + const actual = wrapper.find('dl').length; + + expect(actual).toEqual(expected); + }); + + it('should render one definition container', () => { + const { wrapper } = setup(); + const expected = 1; + const actual = wrapper.find('.definition-list-container').length; + + expect(actual).toEqual(expected); + }); + + it('should render one definition term', () => { + const { wrapper } = setup(); + const expected = 1; + const actual = wrapper.find('.definition-list-term').length; + + expect(actual).toEqual(expected); + }); + + it('should render one definition definition', () => { + const { wrapper } = setup(); + const expected = 1; + const actual = wrapper.find('.definition-list-definition').length; + + expect(actual).toEqual(expected); + }); + + describe('when passing several definitions', () => { + it('should render as many containers', () => { + const { wrapper } = setup({ + definitions: [ + { + term: 'Table name', + description: 'coco.fact_rides', + }, + { + term: 'Root cause', + description: + 'Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod', + }, + { + term: 'Estimate', + description: 'Target fix by MM/DD/YYYY 00:00', + }, + ], + }); + const expected = 3; + const actual = wrapper.find('.definition-list-container').length; + + expect(actual).toEqual(expected); + }); + + it('should render as many terms-definition pairs', () => { + const { wrapper } = setup({ + definitions: [ + { + term: 'Table name', + description: 'coco.fact_rides', + }, + { + term: 'Root cause', + description: + 'Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod', + }, + { + term: 'Estimate', + description: 
'Target fix by MM/DD/YYYY 00:00', + }, + ], + }); + const expected = 3; + const actualTerms = wrapper.find('.definition-list-term').length; + const actualDefinitions = wrapper.find( + '.definition-list-definition' + ).length; + + expect(actualTerms).toEqual(expected); + expect(actualDefinitions).toEqual(expected); + }); + }); + + describe('when passing definitions with html', () => { + it('should render them', () => { + const { wrapper } = setup({ + definitions: [ + { + term: 'Verity checks', + description: + '1 out of 4 checks failed (Link | Ownser)', + }, + { + term: 'Failed DAGs', + description: + '1 out of 4 DAGs failed (Link | Ownser)', + }, + ], + }); + const expected = 2; + const actualTerms = wrapper.find('.definition-list-term').length; + const actualDefinitions = wrapper.find( + '.definition-list-definition' + ).length; + + expect(actualTerms).toEqual(expected); + expect(actualDefinitions).toEqual(expected); + }); + }); + + describe('when passing a custom term width', () => { + it('should set its width', () => { + const { wrapper } = setup({ + termWidth: 200, + }); + const expected = 'min-width: 200px;'; + const actual = wrapper + .find('.definition-list-term') + ?.getDOMNode() + ?.getAttribute('style'); + + expect(actual).toEqual(expected); + }); + }); + + describe('when passing aggregated descriptions', () => { + it('should render them', () => { + const { wrapper } = setup({ + definitions: [ + { + term: 'Table name', + description: [ + { + 'Failed Check': 'coco.fact_rides', + Owner: 'Owner', + }, + { + 'Failed Check 2': 'coco.fact_rides', + Owner: 'Just a normal string', + }, + ], + }, + ], + }); + + const itemGroup = wrapper.find('.definition-list-items-group'); + + expect(itemGroup.length).toEqual(2); + expect(itemGroup.find('.definition-list-term').length).toEqual(4); + }); + }); + }); +}); diff --git a/frontend/amundsen_application/static/js/components/DefinitionList/index.tsx b/frontend/amundsen_application/static/js/components/DefinitionList/index.tsx new file mode 100644 index 0000000000..a5eac10740 --- /dev/null +++ b/frontend/amundsen_application/static/js/components/DefinitionList/index.tsx @@ -0,0 +1,91 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +import * as React from 'react'; +import SanitizedHTML from 'react-sanitized-html'; + +import './styles.scss'; + +export interface DefinitionType { + /** Definition term */ + term: string; + /** Definition body text */ + description: React.ReactNode; +} + +export interface DefinitionListProps { + /** Size to fix the term block, in pixels*/ + termWidth?: number; + /** List of terms and descriptions to render */ + definitions: DefinitionType[]; +} + +export const DefinitionList: React.FC = ({ + definitions, + termWidth, +}) => { + const parseDescription = (description) => { + switch (typeof description) { + case 'object': + return ( + <> + {Array.isArray(description) + ? description.map((item) => { + const items = Object.keys(item).map((key) => ( +
+
+ {key}: +
+
+ {parseDescription(item[key])} +
+
+ )); + + return ( +
{items}
+ ); + }) + : description} + + ); + case 'string': + return ; + default: + return description; + } + }; + + return ( +
+ {definitions.map(({ term, description }) => ( +
+ {Array.isArray(description) ? ( +
+ {parseDescription(description)} +
+ ) : ( + <> +
+ {term} +
+
+ {parseDescription(description)} +
+ + )} +
+ ))} +
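+      {/* Each entry renders a term/description pair; array descriptions are grouped
+          by parseDescription, and string descriptions are rendered as sanitized HTML. */}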
+ ); +}; + +export default DefinitionList; diff --git a/frontend/amundsen_application/static/js/components/DefinitionList/styles.scss b/frontend/amundsen_application/static/js/components/DefinitionList/styles.scss new file mode 100644 index 0000000000..87e62b28d2 --- /dev/null +++ b/frontend/amundsen_application/static/js/components/DefinitionList/styles.scss @@ -0,0 +1,63 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 +@import 'variables'; +@import 'typography'; + +.definition-list { + margin: 0; +} + +.definition-list-container { + display: flex; + padding-bottom: $spacer-2; + width: 100%; + + &:last-child { + padding-bottom: 0; + } +} + +.definition-list-term { + @extend %text-title-w3; + + color: $gray40; + padding-right: $spacer-1; + overflow: hidden; + text-overflow: ellipsis; +} + +.definition-list-definition { + @extend %text-body-w3; + + margin-left: 0; + margin-bottom: 0; +} + +.definition-list-inner-container { + display: flex; + width: 100%; + flex-direction: column; + + > *:not(:last-child) { + border-bottom: 1px solid $gray20; + } +} + +.definition-list-items-group { + padding: $spacer-1 0; + + > * { + display: flex; + + :not(:last-child) { + margin-bottom: $spacer-half; + } + > .definition-list-term { + flex: 0; + flex-basis: 25%; + } + > .definition-list-definition { + flex: 1; + } + } +} diff --git a/frontend/amundsen_application/static/js/components/EditableSection/constants.ts b/frontend/amundsen_application/static/js/components/EditableSection/constants.ts new file mode 100644 index 0000000000..cb12a99714 --- /dev/null +++ b/frontend/amundsen_application/static/js/components/EditableSection/constants.ts @@ -0,0 +1 @@ +export const EDIT_TEXT = 'Click to edit'; diff --git a/frontend/amundsen_application/static/js/components/EditableSection/index.spec.tsx b/frontend/amundsen_application/static/js/components/EditableSection/index.spec.tsx new file mode 100644 index 0000000000..8fdcb92d9f --- /dev/null +++ b/frontend/amundsen_application/static/js/components/EditableSection/index.spec.tsx @@ -0,0 +1,121 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +import * as React from 'react'; +import { shallow } from 'enzyme'; + +import TagInput from 'features/Tags/TagInput'; +import { ResourceType } from 'interfaces/Resources'; +import EditableSection, { EditableSectionProps } from '.'; + +describe('EditableSection', () => { + const setup = (propOverrides?: Partial, children?) 
=> { + const props = { + title: 'defaultTitle', + readOnly: false, + ...propOverrides, + }; + const wrapper = shallow( + {children} + ); + + return { wrapper, props }; + }; + + describe('handleClick', () => { + const clickEvent = { + preventDefault: jest.fn(), + }; + + it('preventDefault on click', () => { + const { wrapper } = setup(); + + wrapper + .find('.editable-section-label-wrapper') + .simulate('click', clickEvent); + + expect(clickEvent.preventDefault).toHaveBeenCalled(); + }); + }); + + describe('setEditMode', () => { + const { wrapper } = setup(); + + it('Enters edit mode after calling setEditMode(true)', () => { + wrapper.instance().setEditMode(true); + + expect(wrapper.state().isEditing).toBe(true); + }); + + it('Exits edit mode after calling setEditMode(false)', () => { + wrapper.instance().setEditMode(false); + + expect(wrapper.state().isEditing).toBe(false); + }); + }); + + describe('render', () => { + const mockTitle = 'Mock'; + const convertTextSpy = jest + .spyOn(EditableSection, 'convertText') + .mockImplementation(() => mockTitle); + const { wrapper, props } = setup( + { title: 'custom title' }, + + ); + + it('renders the converted props.title as the section title', () => { + convertTextSpy.mockClear(); + wrapper.instance().render(); + + expect(convertTextSpy).toHaveBeenCalledWith(props.title); + expect(wrapper.find('.section-title').text()).toBe(mockTitle); + }); + + it('renders children with additional props', () => { + const childProps = wrapper.find(TagInput).props(); + + expect(childProps).toMatchObject({ + isEditing: wrapper.state().isEditing, + setEditMode: wrapper.instance().setEditMode, + }); + }); + + it('renders children as-is for non-react elements', () => { + const child = 'non-react-child'; + const { wrapper } = setup(undefined, child); + + expect(wrapper.find('.editable-section-content').text()).toBe(child); + }); + + it('renders edit button correctly when readOnly=false', () => { + expect(wrapper.find('.edit-button').props().onClick).toBe( + wrapper.instance().toggleEdit + ); + }); + + describe('renders edit link correctly when readOnly=true', () => { + let props; + let wrapper; + + beforeAll(() => { + const setupResult = setup( + { readOnly: true, editUrl: 'test', editText: 'hello' }, +
+ ); + + ({ props, wrapper } = setupResult); + }); + + it('link links to editUrl', () => { + expect(wrapper.find('.edit-button').props().href).toBe(props.editUrl); + }); + }); + + it('does not render button if readOnly=true and there is no external editUrl', () => { + const { wrapper } = setup({ readOnly: true },
); + + expect(wrapper.find('.edit-button').exists()).toBeFalsy(); + }); + }); +}); diff --git a/frontend/amundsen_application/static/js/components/EditableSection/index.tsx b/frontend/amundsen_application/static/js/components/EditableSection/index.tsx new file mode 100644 index 0000000000..e526e53cff --- /dev/null +++ b/frontend/amundsen_application/static/js/components/EditableSection/index.tsx @@ -0,0 +1,160 @@ +/* eslint-disable jsx-a11y/click-events-have-key-events */ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +import * as React from 'react'; +import { OverlayTrigger, Popover } from 'react-bootstrap'; + +import { logClick } from 'utils/analytics'; + +import * as Constants from './constants'; + +import './styles.scss'; + +export interface EditableSectionProps { + title: string; + readOnly?: boolean; + /* Should be used when readOnly=true to prompt users with a relevant explanation for the given use case */ + editText?: string; + /* Should be used when readOnly=true to link to the source where users can edit the given metadata */ + editUrl?: string; +} + +interface EditableSectionState { + isEditing: boolean; +} + +export interface EditableSectionChildProps { + isEditing?: boolean; + setEditMode?: (isEditing: boolean) => void; + readOnly?: boolean; +} + +export class EditableSection extends React.Component< + EditableSectionProps, + EditableSectionState +> { + static defaultProps: Partial = { + editText: Constants.EDIT_TEXT, + }; + + static convertText(str: string): string { + return str + .split(new RegExp('[\\s+_]')) + .map((x) => x.charAt(0).toUpperCase() + x.slice(1).toLowerCase()) + .join(' '); + } + + constructor(props) { + super(props); + + this.state = { + isEditing: false, + }; + } + + setEditMode = (isEditing: boolean) => { + this.setState({ isEditing }); + }; + + toggleEdit = (e: React.MouseEvent) => { + const { isEditing } = this.state; + const { title } = this.props; + const logTitle = EditableSection.convertText(title); + + this.setState({ isEditing: !isEditing }); + logClick(e, { + label: 'Toggle Editable Section', + target_id: `toggle-edit-${logTitle.toLowerCase()}-section`, + }); + }; + + preventDefault = (event: React.MouseEvent) => { + event.preventDefault(); + }; + + renderButton = (): React.ReactNode => { + const { isEditing } = this.state; + + return ( + + ); + }; + + renderReadOnlyButton = (): React.ReactNode => { + const { editText, editUrl } = this.props; + const popoverHoverFocus = ( + {editText} + ); + + if (!editUrl) { + return null; + } + + return ( + + + {Constants.EDIT_TEXT} + + + + ); + }; + + render() { + const { children, title, readOnly = false } = this.props; + const { isEditing } = this.state; + + const childrenWithProps = React.Children.map(children, (child) => { + if (!React.isValidElement(child)) { + return child; + } + + return React.cloneElement(child, { + readOnly, + isEditing, + setEditMode: this.setEditMode, + }); + }); + + return ( +
+ +
{childrenWithProps}
+
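+        {/* Children are cloned with readOnly, isEditing and setEditMode so nested
+            editors (e.g. TagInput) can toggle this section's edit state. */}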
+ ); + } +} + +export default EditableSection; diff --git a/frontend/amundsen_application/static/js/components/EditableSection/styles.scss b/frontend/amundsen_application/static/js/components/EditableSection/styles.scss new file mode 100644 index 0000000000..3dd39a2a83 --- /dev/null +++ b/frontend/amundsen_application/static/js/components/EditableSection/styles.scss @@ -0,0 +1,41 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +@import 'variables'; + +.editable-section { + .editable-section-label { + display: block; + font-weight: 400; + margin-bottom: 0; + } + + .editable-section-label-wrapper { + margin-bottom: $spacer-1; + } + + .section-title { + color: $text-tertiary; + margin-bottom: $spacer-1/2; + font-weight: 700; + } + + .edit-button { + margin-left: $spacer-1/2; + opacity: 0; + + img { + margin-bottom: auto; + } + + &.active { + opacity: 1; + } + } + + &:hover { + .edit-button { + opacity: 1; + } + } +} diff --git a/frontend/amundsen_application/static/js/components/EditableText/constants.ts b/frontend/amundsen_application/static/js/components/EditableText/constants.ts new file mode 100644 index 0000000000..dab82650ef --- /dev/null +++ b/frontend/amundsen_application/static/js/components/EditableText/constants.ts @@ -0,0 +1,6 @@ +export const REFRESH_MESSAGE = + 'This text is out of date, please refresh the component'; +export const REFRESH_BUTTON_TEXT = 'Refresh'; +export const UPDATE_BUTTON_TEXT = 'Update'; +export const CANCEL_BUTTON_TEXT = 'Cancel'; +export const ADD_MESSAGE = 'Add Description'; diff --git a/frontend/amundsen_application/static/js/components/EditableText/index.spec.tsx b/frontend/amundsen_application/static/js/components/EditableText/index.spec.tsx new file mode 100644 index 0000000000..a6f3c119bb --- /dev/null +++ b/frontend/amundsen_application/static/js/components/EditableText/index.spec.tsx @@ -0,0 +1,154 @@ +// Copyright Contributors to the Amundsen project. 
+// SPDX-License-Identifier: Apache-2.0 + +import * as React from 'react'; +import * as ReactMarkdown from 'react-markdown'; + +import { shallow } from 'enzyme'; +import { + CANCEL_BUTTON_TEXT, + REFRESH_BUTTON_TEXT, + REFRESH_MESSAGE, + UPDATE_BUTTON_TEXT, +} from 'components/EditableText/constants'; +import EditableText, { EditableTextProps } from '.'; + +const setup = (propOverrides?: Partial) => { + const props = { + editable: true, + isEditing: true, + maxLength: 4000, + onSubmitValue: jest.fn(), + getLatestValue: jest.fn(), + refreshValue: '', + setEditMode: jest.fn(), + value: 'currentValue', + ...propOverrides, + }; + // eslint-disable-next-line react/jsx-props-no-spreading + const wrapper = shallow(); + + return { + props, + wrapper, + }; +}; + +describe('EditableText', () => { + describe('componentDidUpdate', () => { + it('sets isDisabled:true when refresh value does not equal value', () => { + const { wrapper, props } = setup({ + isEditing: true, + refreshValue: 'new value', + value: 'different value', + }); + + wrapper.instance().componentDidUpdate(props); + const state = wrapper.state(); + + expect(state.isDisabled).toBe(true); + }); + }); + + describe('exitEditMode', () => { + it('updates the state', () => { + const { wrapper, props } = setup(); + const instance = wrapper.instance(); + const setEditModeSpy = jest.spyOn(props, 'setEditMode'); + + setEditModeSpy.mockClear(); + instance.exitEditMode(); + + expect(setEditModeSpy).toHaveBeenCalledWith(false); + expect(wrapper.state()).toMatchObject({ + isDisabled: false, + }); + }); + }); + + describe('render', () => { + describe('not in edit mode', () => { + it('renders a ReactMarkdown component', () => { + const { wrapper } = setup({ + isEditing: false, + value: '', + }); + const markdown = wrapper.find(ReactMarkdown); + + expect(markdown.exists()).toBe(true); + }); + + it('renders an edit link if it is editable and the text is empty', () => { + const { wrapper } = setup({ + isEditing: false, + value: '', + }); + const editLink = wrapper.find('.edit-link'); + + expect(editLink.exists()).toBe(true); + }); + + it('does not render an edit link if it is not editable', () => { + const { wrapper } = setup({ editable: false }); + const editLink = wrapper.find('.edit-link'); + + expect(editLink.exists()).toBe(false); + }); + }); + + describe('in edit mode', () => { + it('renders a textarea ', () => { + const { wrapper, props } = setup({ + isEditing: true, + value: '', + }); + const textarea = wrapper.find('textarea'); + + expect(textarea.exists()).toBe(true); + expect(textarea.props()).toMatchObject({ + maxLength: props.maxLength, + defaultValue: wrapper.state().value, + disabled: wrapper.state().isDisabled, + }); + }); + + it('when disabled, renders the refresh message and button', () => { + const { wrapper } = setup({ + isEditing: true, + value: '', + }); + + wrapper.setState({ isDisabled: true }); + const refreshMessage = wrapper.find('.refresh-message'); + + expect(refreshMessage.text()).toBe(REFRESH_MESSAGE); + + const refreshButton = wrapper.find('.refresh-button'); + + expect(refreshButton.text()).toMatch(REFRESH_BUTTON_TEXT); + }); + + it('when not disabled, renders the update text button', () => { + const { wrapper } = setup({ + isEditing: true, + value: '', + }); + + wrapper.setState({ isDisabled: false }); + const updateButton = wrapper.find('.update-button'); + + expect(updateButton.text()).toMatch(UPDATE_BUTTON_TEXT); + }); + + it('renders the cancel button', () => { + const { wrapper } = setup({ + isEditing: true, + value: 
'', + }); + const cancelButton = wrapper.find('.cancel-button'); + + expect(cancelButton.text()).toMatch(CANCEL_BUTTON_TEXT); + }); + }); + }); +}); diff --git a/frontend/amundsen_application/static/js/components/EditableText/index.tsx b/frontend/amundsen_application/static/js/components/EditableText/index.tsx new file mode 100644 index 0000000000..c7a179912a --- /dev/null +++ b/frontend/amundsen_application/static/js/components/EditableText/index.tsx @@ -0,0 +1,245 @@ +// Copyright Contributors to the Amundsen project. +// SPDX-License-Identifier: Apache-2.0 + +import * as autosize from 'autosize'; +import * as React from 'react'; +import * as ReactMarkdown from 'react-markdown'; +import remarkGfm from 'remark-gfm'; + +import { EditableSectionChildProps } from 'components/EditableSection'; +import { logClick } from 'utils/analytics'; + +import { + CANCEL_BUTTON_TEXT, + REFRESH_BUTTON_TEXT, + REFRESH_MESSAGE, + ADD_MESSAGE, + UPDATE_BUTTON_TEXT, +} from './constants'; + +import './styles.scss'; + +export interface StateFromProps { + refreshValue?: string; +} + +export interface DispatchFromProps { + getLatestValue?: (onSuccess?: () => any, onFailure?: () => any) => void; + onSubmitValue?: ( + newValue: string, + onSuccess?: () => any, + onFailure?: () => any + ) => void; +} + +export interface ComponentProps { + editable?: boolean; + maxLength?: number; + value?: string; + allowDangerousHtml?: boolean; +} + +export type EditableTextProps = ComponentProps & + DispatchFromProps & + StateFromProps & + EditableSectionChildProps; + +interface EditableTextState { + value?: string; + isDisabled: boolean; +} + +class EditableText extends React.Component< + EditableTextProps, + EditableTextState +> { + readonly textAreaRef: React.RefObject; + + public static defaultProps: EditableTextProps = { + editable: true, + maxLength: 500, + value: '', + }; + + constructor(props: EditableTextProps) { + super(props); + this.textAreaRef = React.createRef(); + + this.state = { + isDisabled: false, + value: props.value, + }; + } + + componentDidUpdate(prevProps: EditableTextProps) { + const { value: stateValue, isDisabled } = this.state; + const { + value: propValue, + isEditing, + refreshValue, + getLatestValue, + } = this.props; + + if (prevProps.value !== propValue) { + this.setState({ value: propValue }); + } else if (isEditing && !prevProps.isEditing) { + const textArea = this.textAreaRef.current; + + if (textArea) { + autosize(textArea); + textArea.focus(); + } + + if (getLatestValue) { + getLatestValue(); + } + } else if ( + (refreshValue || stateValue) && + refreshValue !== stateValue && + !isDisabled + ) { + // disable the component if a refresh is needed + this.setState({ isDisabled: true }); + } + } + + handleExitEditMode = (e: React.MouseEvent) => { + logClick(e, { + label: 'Cancel Editable Text', + }); + this.exitEditMode(); + }; + + exitEditMode = () => { + const { setEditMode } = this.props; + + setEditMode?.(false); + }; + + handleEnterEditMode = (e: React.MouseEvent) => { + const { setEditMode } = this.props; + + logClick(e, { + label: 'Add Editable Text', + }); + setEditMode?.(true); + }; + + handleRefreshText = (e: React.MouseEvent) => { + const { refreshValue } = this.props; + const textArea = this.textAreaRef.current; + + this.setState({ value: refreshValue, isDisabled: false }); + logClick(e, { + label: 'Refresh Editable Text', + }); + + if (textArea && refreshValue) { + textArea.value = refreshValue; + autosize.update(textArea); + } + }; + + handleUpdateText = (e: React.MouseEvent) => { + 
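+    // Submit the edited text (empty text is not submitted): on success persist the
+    // new value locally and leave edit mode; on failure exit edit mode without updating state.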
const { setEditMode, onSubmitValue } = this.props; + const newValue = this.textAreaRef.current?.value; + + const onSuccessCallback = () => { + setEditMode?.(false); + this.setState({ value: newValue }); + }; + const onFailureCallback = () => { + this.exitEditMode(); + }; + + logClick(e, { + label: 'Update Editable Text', + }); + + if (newValue) { + onSubmitValue?.(newValue, onSuccessCallback, onFailureCallback); + } + }; + + render() { + const { isEditing, editable, maxLength, allowDangerousHtml } = this.props; + const { value = '', isDisabled } = this.state; + + if (!isEditing) { + return ( +
+
+ + {value} + +
+ {editable && !value && ( + + )} +
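+      {/* Read-only view: the value renders as GitHub-flavored markdown (remark-gfm);
+          when editable and empty, an edit link prompts the user to add a description. */}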
+ ); + } + + return ( +
+