Train word embeddings from a CoNLL corpus file

Train word2vec word embeddings using a CoNLL-U corpus as input.

Depends on: CoNLL Utils (the `conll_utils` module) and gensim.

train_word_embeddings

# Take a CoNLL-U corpus and train word embeddings (gensim word2vec).
import argparse
import os
import sys
from conll_utils import *
from gensim.models.word2vec import *

# random
from random import shuffle

# Configure the root logger at INFO level with timestamps so that progress
# messages from libraries using the standard logging module (presumably
# including gensim's vocabulary/training progress) are actually displayed.
import logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s : %(levelname)s : %(message)s',
)

# Command-line interface for the training script.
parser = argparse.ArgumentParser(description='Process CoNLL-U corpus and generate a word2vec model.')
# Required positional arguments
parser.add_argument('input_file', type=str,
                    help='Input CoNLL-U corpus (UTF-8)')
parser.add_argument('output_file', type=str,
                    help='Base output filename of word2vec model (gensim)')
# Passed straight through as the `epochs` argument of the training call;
# sentences are not reshuffled between epochs.
parser.add_argument('--epochs', type=int, default=10,
                    help='Training epochs, i.e. passes over the corpus (default 10).')
parser.add_argument('--min-sentence-length', type=int, default=5,
                    help='If sentence is shorter than N Eojeols (full words), it will not be processed for inclusion in the word2vec model (default 5)')
parser.add_argument('--dimension', type=int, default=100,
                    help='word2vec: dimensionality of feature vectors (default 100). doc2vec mode may demand a higher value.')
parser.add_argument('--window', type=int, default=5,
                    help='word2vec: maximum distance between current and predicted word (default 5). doc2vec mode may demand a higher value.')
parser.add_argument('--workers', type=int, default=4,
                    help='word2vec: use this many worker threads to train the model (default 4)')
parser.add_argument('--min-word-occurrence', type=int, default=5,
                    help='word2vec: ignore all words with total frequency lower than this (default 5)')
# Training architecture. Skip-gram is the default. --use-skipgram is kept so
# existing command lines keep working, and --use-cbow writes False to the same
# destination so CBOW can actually be selected (a store_true flag whose
# default is True can never be switched off on its own).
parser.add_argument('--use-skipgram', dest='use_skipgram', action='store_true', default=True,
                    help='Use skip-gram (default).')
parser.add_argument('--use-cbow', dest='use_skipgram', action='store_false',
                    help='Use CBOW instead of the default skip-gram.')
parser.add_argument('--min-word-length', type=int, default=0,
                    help='word2vec: ignore all words with a length lower than this (default 0).')
#parser.add_argument('--char2vec', action='store_true', default=False,
#                    help='Create char2vec model (make all words their own chars).')

# Running totals filled in by the loop below.
totalWordCount = 0          # eojeols (tokens) kept after the --min-word-length filter
trainLabeledSentences = []  # one list of "<m[0]>-<m[1]>" morpheme strings per kept sentence

args = parser.parse_args()

trainingCorpus = ConllFile(keepMalformed=True,
                           checkParserConformity=False,
                           projectivize=False,
                           enableLemmaMorphemes=True,
                           compatibleJamo=True)

# `with` closes the file even if reading or parsing raises.
with open(args.input_file, 'r', encoding='utf-8') as fd:
    trainingCorpus.read(fd.read())

for sent in trainingCorpus.sentences:
    my_sent = []
    keptEojeols = 0
    for token in sent.tokens:
        # Skip eojeols whose surface form is shorter than --min-word-length
        # (a value <= 0 disables the filter).
        if args.min_word_length > 0 and len(token.FORM) < args.min_word_length:
            continue
        # Each morpheme becomes one training "word": m[0] + '-' + m[1].
        # NOTE(review): presumably (form, POS tag) pairs -- confirm in conll_utils.
        my_sent.extend(m[0] + '-' + m[1] for m in token.morphemes)
        keptEojeols += 1
    totalWordCount += keptEojeols

    # A sentence that yielded no morphemes contributes nothing to training.
    if not my_sent:
        continue
    # --min-sentence-length is documented in eojeols (full words), so compare
    # against the number of kept eojeols, not the (larger) morpheme count.
    if args.min_sentence_length <= 0 or keptEojeols >= args.min_sentence_length:
        trainLabeledSentences.append(my_sent)

print('Beginning to build model...')

# Nothing survived the filters: a word2vec model cannot be trained on an
# empty corpus, so stop with a clear message and a non-zero exit status.
if not trainLabeledSentences:
    print('No sentences passed the length filters; nothing to train.', file=sys.stderr)
    sys.exit(1)

sgFlag = 1 if args.use_skipgram else 0  # gensim: sg=1 -> skip-gram, sg=0 -> CBOW

w2vOptions = dict(min_count=args.min_word_occurrence,
                  window=args.window,
                  workers=args.workers,
                  sg=sgFlag)
try:
    # gensim 4.x renamed the dimensionality keyword from `size` to
    # `vector_size` (passing `size` raises TypeError there). Try the new name
    # first and fall back to the old one for gensim 3.x.
    try:
        model = Word2Vec(vector_size=args.dimension, **w2vOptions)
    except TypeError:
        model = Word2Vec(size=args.dimension, **w2vOptions)

    print('Building vocabulary...')
    model.build_vocab(trainLabeledSentences)

    # All epochs run inside this single train() call; sentences are not
    # reshuffled between epochs.
    model.train(trainLabeledSentences,
                total_examples=model.corpus_count,
                epochs=args.epochs)

    model.save(args.output_file)
except Exception as inst:
    # Top-level boundary: report the failure with a traceback and exit
    # non-zero instead of silently finishing with a success status.
    logging.exception('Model training failed')
    print('Unexpected error:', inst, file=sys.stderr)
    print('You may not have reached the minimum word occurrence count.', file=sys.stderr)
    sys.exit(1)

 

Leave a Reply

Your email address will not be published. Required fields are marked *