/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.mapreduce.lib.db;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.MRJobConfig;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;

/**
 * An InputFormat that reads input data from an SQL table.
 * <p>
 * DBInputFormat emits LongWritables containing the record number as
 * the key and DBWritables as the value.
 * <p>
 * The SQL query and input class can be specified using one of the two
 * setInput methods.
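 * <p>
 * A minimal usage sketch; MyRecord (a DBWritable implementation) and the
 * driver class, connection URL, credentials, table and field names below
 * are placeholders, not part of this API:
 * <pre>
 *   Job job = Job.getInstance(conf);
 *   // placeholder JDBC driver, URL and credentials
 *   DBConfiguration.configureDB(job.getConfiguration(),
 *       "com.mysql.jdbc.Driver", "jdbc:mysql://localhost/mydb",
 *       "user", "password");
 *   DBInputFormat.setInput(job, MyRecord.class,
 *       "employees", null, "id", "id", "name", "salary");
 * </pre>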
 */
@InterfaceAudience.Public
@InterfaceStability.Stable
public class DBInputFormat<T extends DBWritable>
    extends InputFormat<LongWritable, T> implements Configurable {

  private String dbProductName = "DEFAULT";

  /**
   * A Class that does nothing, implementing DBWritable
   */
  @InterfaceStability.Evolving
  public static class NullDBWritable implements DBWritable, Writable {
    @Override
    public void readFields(DataInput in) throws IOException { }
    @Override
    public void readFields(ResultSet arg0) throws SQLException { }
    @Override
    public void write(DataOutput out) throws IOException { }
    @Override
    public void write(PreparedStatement arg0) throws SQLException { }
  }

  /**
   * An InputSplit that spans a set of rows.
   */
  @InterfaceStability.Evolving
  public static class DBInputSplit extends InputSplit implements Writable {

    private long end = 0;
    private long start = 0;

    /**
     * Default Constructor
     */
    public DBInputSplit() {
    }

    /**
     * Convenience Constructor
     * @param start the index of the first row to select
     * @param end the index of the last row to select
     */
    public DBInputSplit(long start, long end) {
      this.start = start;
      this.end = end;
    }

    /** {@inheritDoc} */
    public String[] getLocations() throws IOException {
      // TODO Add a layer to enable SQL "sharding" and support locality
      return new String[] {};
    }

    /**
     * @return The index of the first row to select
     */
    public long getStart() {
      return start;
    }

    /**
     * @return The index of the last row to select
     */
    public long getEnd() {
      return end;
    }

    /**
     * @return The total row count in this split
     */
    public long getLength() throws IOException {
      return end - start;
    }

    /** {@inheritDoc} */
    public void readFields(DataInput input) throws IOException {
      start = input.readLong();
      end = input.readLong();
    }

    /** {@inheritDoc} */
    public void write(DataOutput output) throws IOException {
      output.writeLong(start);
      output.writeLong(end);
    }
  }

  private String conditions;

  private Connection connection;

  private String tableName;

  private String[] fieldNames;

  private DBConfiguration dbConf;

  /** {@inheritDoc} */
  public void setConf(Configuration conf) {

    dbConf = new DBConfiguration(conf);

    try {
      getConnection();

      DatabaseMetaData dbMeta = connection.getMetaData();
      this.dbProductName = dbMeta.getDatabaseProductName().toUpperCase();
    }
    catch (Exception ex) {
      throw new RuntimeException(ex);
    }

    tableName = dbConf.getInputTableName();
    fieldNames = dbConf.getInputFieldNames();
    conditions = dbConf.getInputConditions();
  }

  public Configuration getConf() {
    return dbConf.getConf();
  }

  public DBConfiguration getDBConf() {
    return dbConf;
  }

  public Connection getConnection() {
    try {
      if (null == this.connection) {
        // The connection is not open (first use or it was closed); create it.
        this.connection = dbConf.getConnection();
        this.connection.setAutoCommit(false);
        this.connection.setTransactionIsolation(
            Connection.TRANSACTION_SERIALIZABLE);
      }
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
    return connection;
  }

  public String getDBProductName() {
    return dbProductName;
  }

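  /**
   * Creates a record reader suited to the database in use: an Oracle- or
   * MySQL-specific reader when the reported product name indicates one of
   * those databases, and the generic DBRecordReader otherwise.
   *
   * @param split the split describing the range of rows to read
   * @param conf the job configuration
   * @return a record reader for the given split
   * @throws IOException if the reader cannot be created
   */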
  protected RecordReader<LongWritable, T> createDBRecordReader(DBInputSplit split,
      Configuration conf) throws IOException {

    @SuppressWarnings("unchecked")
    Class<T> inputClass = (Class<T>) (dbConf.getInputClass());
    try {
      // use database product name to determine appropriate record reader.
      if (dbProductName.startsWith("ORACLE")) {
        // use Oracle-specific db reader.
        return new OracleDBRecordReader<T>(split, inputClass,
            conf, getConnection(), getDBConf(), conditions, fieldNames,
            tableName);
      } else if (dbProductName.startsWith("MYSQL")) {
        // use MySQL-specific db reader.
        return new MySQLDBRecordReader<T>(split, inputClass,
            conf, getConnection(), getDBConf(), conditions, fieldNames,
            tableName);
      } else {
        // Generic reader.
        return new DBRecordReader<T>(split, inputClass,
            conf, getConnection(), getDBConf(), conditions, fieldNames,
            tableName);
      }
    } catch (SQLException ex) {
      throw new IOException(ex.getMessage(), ex);
    }
  }

  /** {@inheritDoc} */
  @SuppressWarnings("unchecked")
  public RecordReader<LongWritable, T> createRecordReader(InputSplit split,
      TaskAttemptContext context) throws IOException, InterruptedException {

    return createDBRecordReader((DBInputSplit) split, context.getConfiguration());
  }

  /** {@inheritDoc} */
  public List<InputSplit> getSplits(JobContext job) throws IOException {

    ResultSet results = null;
    Statement statement = null;
    try {
      statement = connection.createStatement();

      results = statement.executeQuery(getCountQuery());
      results.next();

      long count = results.getLong(1);
      int chunks = job.getConfiguration().getInt(MRJobConfig.NUM_MAPS, 1);
      long chunkSize = (count / chunks);

      results.close();
      statement.close();

      List<InputSplit> splits = new ArrayList<InputSplit>();

      // Split the rows into n-number of chunks and adjust the last chunk
      // accordingly
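      // (e.g. count = 10 and chunks = 3 give chunkSize = 3 and splits
      // covering rows [0, 3), [3, 6) and [6, 10); the last split absorbs
      // the remainder)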
      for (int i = 0; i < chunks; i++) {
        DBInputSplit split;

        if ((i + 1) == chunks)
          split = new DBInputSplit(i * chunkSize, count);
        else
          split = new DBInputSplit(i * chunkSize, (i * chunkSize)
              + chunkSize);

        splits.add(split);
      }

      connection.commit();
      return splits;
    } catch (SQLException e) {
      throw new IOException("Got SQLException", e);
    } finally {
      try {
        if (results != null) { results.close(); }
      } catch (SQLException e1) {}
      try {
        if (statement != null) { statement.close(); }
      } catch (SQLException e1) {}

      closeConnection();
    }
  }

  /** Returns the query for getting the total number of rows;
   * subclasses can override this for custom behaviour. */
  protected String getCountQuery() {

    if (dbConf.getInputCountQuery() != null) {
      return dbConf.getInputCountQuery();
    }

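    // Otherwise build a default count query of the form
    // "SELECT COUNT(*) FROM <tableName> [WHERE <conditions>]".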
    StringBuilder query = new StringBuilder();
    query.append("SELECT COUNT(*) FROM " + tableName);

    if (conditions != null && conditions.length() > 0)
      query.append(" WHERE " + conditions);
    return query.toString();
  }

  /**
   * Initializes the map-part of the job with the appropriate input settings.
   *
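   * <p>
   * A sketch of a typical call; MyRecord and the table, condition and field
   * names are placeholders:
   * <pre>
   *   DBInputFormat.setInput(job, MyRecord.class,
   *       "employees", "active = 1", "id", "id", "name", "salary");
   * </pre>
   *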
   * @param job The map-reduce job
   * @param inputClass the class object implementing DBWritable, which is the
   * Java object holding tuple fields.
   * @param tableName The table to read data from
   * @param conditions The conditions with which to select data,
   * e.g. '(updated > 20070101 AND length > 0)'
   * @param orderBy the field names in the ORDER BY clause.
   * @param fieldNames The field names in the table
   * @see #setInput(Job, Class, String, String)
   */
  public static void setInput(Job job,
      Class<? extends DBWritable> inputClass,
      String tableName, String conditions,
      String orderBy, String... fieldNames) {
    job.setInputFormatClass(DBInputFormat.class);
    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
    dbConf.setInputClass(inputClass);
    dbConf.setInputTableName(tableName);
    dbConf.setInputFieldNames(fieldNames);
    dbConf.setInputConditions(conditions);
    dbConf.setInputOrderBy(orderBy);
  }

  /**
   * Initializes the map-part of the job with the appropriate input settings.
   *
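   * <p>
   * A sketch of a typical call; MyRecord and the query strings are
   * placeholders:
   * <pre>
   *   DBInputFormat.setInput(job, MyRecord.class,
   *       "SELECT id, name FROM employees ORDER BY id",
   *       "SELECT COUNT(id) FROM employees");
   * </pre>
   *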
   * @param job The map-reduce job
   * @param inputClass the class object implementing DBWritable, which is the
   * Java object holding tuple fields.
   * @param inputQuery the input query to select fields. Example:
   * "SELECT f1, f2, f3 FROM Mytable ORDER BY f1"
   * @param inputCountQuery the input query that returns
   * the number of records in the table.
   * Example: "SELECT COUNT(f1) FROM Mytable"
   * @see #setInput(Job, Class, String, String, String, String...)
   */
  public static void setInput(Job job,
      Class<? extends DBWritable> inputClass,
      String inputQuery, String inputCountQuery) {
    job.setInputFormatClass(DBInputFormat.class);
    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
    dbConf.setInputClass(inputClass);
    dbConf.setInputQuery(inputQuery);
    dbConf.setInputCountQuery(inputCountQuery);
  }

  protected void closeConnection() {
    try {
      if (null != this.connection) {
        this.connection.close();
        this.connection = null;
      }
    } catch (SQLException sqlE) { } // ignore exception on close.
  }
}